Example #1
    def __init__(self, config):
        self.logger = ModelLogger(config, dirname=config['dir'])
        self.dirname = self.logger.dirname
        cuda = config['cuda']
        cuda_id = config['cuda_id']
        if not cuda:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda' if cuda_id < 0 else 'cuda:%d' %
                                       cuda_id)

        datasets = prepare_datasets(config)
        train_set = datasets['train']
        dev_set = datasets['dev']
        test_set = datasets['test']

        # Evaluation Metrics:
        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        if train_set:
            self.train_loader = DataLoader(train_set,
                                           batch_size=config['batch_size'],
                                           shuffle=config['shuffle'],
                                           collate_fn=lambda x: x,
                                           pin_memory=True)
            self._n_train_batches = len(train_set) // config['batch_size']
        else:
            self.train_loader = None

        if dev_set:
            self.dev_loader = DataLoader(dev_set,
                                         batch_size=config['batch_size'],
                                         shuffle=False,
                                         collate_fn=lambda x: x,
                                         pin_memory=True)
            self._n_dev_batches = len(dev_set) // config['batch_size']
        else:
            self.dev_loader = None

        if test_set:
            self.test_loader = DataLoader(test_set,
                                          batch_size=config['batch_size'],
                                          shuffle=False,
                                          collate_fn=lambda x: x,
                                          pin_memory=True)
            self._n_test_batches = len(test_set) // config['batch_size']
            self._n_test_examples = len(test_set)
        else:
            self.test_loader = None

        self._n_train_examples = 0
        self.model = Model(config, train_set)
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
        self.model.network = self.model.network.to(self.device)
        self.config = config
        self.is_test = False
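
Every example in this section leans on an AverageMeter for metric tracking but never defines it. The sketch below is a minimal, assumed implementation covering just the interface the examples call (update(val, n), reset(), mean(), average(), .avg, .last); the actual class in these repositories may differ.

class AverageMeter(object):
    """Minimal sketch (assumption) of the meter used throughout these examples."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.last = 0.0   # most recent value passed to update()
        self.sum = 0.0    # weighted running sum
        self.count = 0    # total weight seen so far

    def update(self, val, n=1):
        self.last = val
        self.sum += val * n
        self.count += n

    def mean(self):
        return self.sum / self.count if self.count > 0 else 0.0

    # Some call sites use .average() or .avg instead of .mean().
    average = mean

    @property
    def avg(self):
        return self.mean()
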
Example #2
    def validation(self):
        self.model.eval()
        self.evaluator_1.reset()
        self.evaluator_2.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output_1, output_2 = self.model(image)

            loss_1 = self.criterion(output_1, target)
            loss_2 = self.criterion(output_2, target)
            test_loss += loss_1.item() + loss_2.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            pred_1 = torch.argmax(output_1, dim=1)
            pred_2 = torch.argmax(output_2, dim=1)

            # Add batch sample into evaluator
            self.evaluator_1.add_batch(target, pred_1)
            self.evaluator_2.add_batch(target, pred_2)

        mIoU_1 = self.evaluator_1.Mean_Intersection_over_Union()
        mIoU_2 = self.evaluator_2.Mean_Intersection_over_Union()

        print('Validation:')
        print("mIoU_1:{}, mIoU_2: {}".format(mIoU_1, mIoU_2))
Example #3
    def __init__(self, config):

        self.config = config
        tokenizer_model = MODELS[config['model_name']]

        self.train_loader, self.dev_loader, tokenizer = prepare_datasets(
            config, tokenizer_model)

        self._n_dev_batches = len(
            self.dev_loader.dataset) // config['batch_size']
        self._n_train_batches = len(
            self.train_loader.dataset) // config['batch_size']
        if config['cuda']:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device('cpu')
        print("using device:", self.device)

        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        self.model = Model(config, MODELS[config['model_name']], self.device,
                           tokenizer).to(self.device)
        # Total number of optimization steps; typically fed to an LR
        # scheduler, though no scheduler is created here.
        t_total = len(
            self.train_loader
        ) // config['gradient_accumulation_steps'] * config['max_epochs']
        self.optimizer = AdamW(self.model.parameters(),
                               lr=config['lr'],
                               eps=config['adam_epsilon'])
        self.optimizer.zero_grad()
        self._n_train_examples = 0
        self._epoch = self._best_epoch = 0
        self._best_f1 = 0
        self._best_em = 0
        self.restored = False
        if config['pretrained_dir'] is not None:
            if config['mode'] == 'train':
                self.restore()
            else:
                self.load_model()
Example #4
    def time_measure(self):
        time_meter_1 = AverageMeter()
        time_meter_2 = AverageMeter()
        self.model.eval()
        self.evaluator_1.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                _, _, t1, t2 = self.model.time_measure(image)
            if t1 is not None:
                time_meter_1.update(t1)
            time_meter_2.update(t2)
        if t1 is not None:
            print(time_meter_1.average())
        print(time_meter_2.average())
Example #5
    def dynamic_inference(self, threshold, confidence):
        self.model.eval()
        self.evaluator_1.reset()
        time_meter = AverageMeter()
        if confidence == 'edm':
            self.edm.eval()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        total_earlier_exit = 0
        confidence_value_avg = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output, earlier_exit, tic, confidence_value = \
                    self.model.dynamic_inference(
                        image, threshold=threshold,
                        confidence=confidence, edm=self.edm)
            total_earlier_exit += earlier_exit
            confidence_value_avg += confidence_value
            time_meter.update(tic)

            loss = self.criterion(output, target)
            test_loss += loss.item()
            pred = torch.argmax(output, dim=1)

            # Add batch sample into evaluator
            self.evaluator_1.add_batch(target, pred)
            tbar.set_description('earlier_exit_num: %.1f' %
                                 (total_earlier_exit))
        mIoU = self.evaluator_1.Mean_Intersection_over_Union()

        print('Validation:')
        print("mIoU: {}".format(mIoU))
        print("mean_inference_time: {}".format(time_meter.average()))
        print("fps: {}".format(1.0 / time_meter.average()))
        print("num_earlier_exit: {}".format(total_earlier_exit / 500 * 100))
        print("avg_confidence: {}".format(confidence_value_avg / 500))
Example #6
class ModelHandler():
    def __init__(self, config):

        self.config = config
        tokenizer_model = MODELS[config['model_name']]

        self.train_loader, self.dev_loader, tokenizer = prepare_datasets(
            config, tokenizer_model)

        self._n_dev_batches = len(
            self.dev_loader.dataset) // config['batch_size']
        self._n_train_batches = len(
            self.train_loader.dataset) // config['batch_size']
        if config['cuda']:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device('cpu')
        print("using device:", self.device)

        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        self.model = Model(config, MODELS[config['model_name']], self.device,
                           tokenizer).to(self.device)
        t_total = len(
            self.train_loader
        ) // config['gradient_accumulation_steps'] * config['max_epochs']
        self.optimizer = AdamW(self.model.parameters(),
                               lr=config['lr'],
                               eps=config['adam_epsilon'])
        self.optimizer.zero_grad()
        self._n_train_examples = 0
        self._epoch = self._best_epoch = 0
        self._best_f1 = 0
        self._best_em = 0
        self.restored = False
        if config['pretrained_dir'] is not None:
            if config['mode'] == 'train':
                self.restore()
            else:
                self.load_model()

    def train(self):
        timer = Timer(' timer')

        if not self.restored:
            print("\n>>> Dev Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            self._run_epoch(self.dev_loader,
                            training=False,
                            verbose=self.config['verbose'],
                            save=False)

            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._dev_f1.mean(),
                                  self._dev_em.mean()))
            self._best_f1 = self._dev_f1.mean()
            self._best_em = self._dev_em.mean()
        while self._stop_condition(self._epoch):
            self._epoch += 1
            print("\n>>> Train Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            if not self.restored:
                self.train_loader.prepare()
            self.restored = False

            self._run_epoch(self.train_loader,
                            training=True,
                            verbose=self.config['verbose'])
            format_str = "Training Epoch {} -- Loss: {:0.4f}, F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._train_loss.mean(),
                                  self._train_f1.mean(),
                                  self._train_em.mean()))
            print("\n>>> Dev Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            self.dev_loader.prepare()
            self._run_epoch(self.dev_loader,
                            training=False,
                            verbose=self.config['verbose'],
                            save=False)
            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._dev_f1.mean(),
                                  self._dev_em.mean()))
            print("has finish :{} epoch, remaining time:{}".format(
                self._epoch,
                timer.remains(self.config['max_epochs'], self._epoch)))

            if self._best_f1 <= self._dev_f1.mean():
                self._best_epoch = self._epoch
                self._best_f1 = self._dev_f1.mean()
                self._best_em = self._dev_em.mean()
                print("!!! Updated: F1: {:0.2f}, EM: {:0.2f}".format(
                    self._best_f1, self._best_em))
            self._reset_metrics()
            self.save(self._epoch)

    def load_model(self):
        restored_params = torch.load(self.config['pretrained_dir'] +
                                     '/best/model.pth')
        self.model.load_state_dict(restored_params['model'])

    def restore(self):
        if not os.path.exists(self.config['pretrained_dir']):
            print("directory doesn't exist, cannot restore")
            return
        restored_params = torch.load(self.config['pretrained_dir'] +
                                     '/latest/model.pth')
        self.model.load_state_dict(restored_params['model'])
        self.optimizer.load_state_dict(restored_params['optimizer'])
        self._epoch = restored_params['epoch']
        self._best_epoch = restored_params['best_epoch']
        self._n_train_examples = restored_params['train_examples']
        self._best_f1 = restored_params['best_f1']
        self._best_em = restored_params['best_em']
        examples = restored_params['dataloader_examples']
        batch_state = restored_params['dataloader_batch_state']
        state = restored_params['dataloader_state']
        self.train_loader.restore(examples, state, batch_state)

        self.restored = True

    def save(self, save_epoch_val):
        if not os.path.exists(self.config['save_state_dir']):
            os.mkdir(self.config['save_state_dir'])

        if self._best_epoch == self._epoch:
            if not os.path.exists(self.config['save_state_dir'] + '/best'):
                os.mkdir(self.config['save_state_dir'] + '/best')
            save_dic = {
                'epoch': self._epoch,
                'best_epoch': self._best_epoch,
                'train_examples': self._n_train_examples,
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_f1': self._best_f1,
                'best_em': self._best_em
            }
            torch.save(save_dic,
                       self.config['save_state_dir'] + '/best/model.pth')

        if not os.path.exists(self.config['save_state_dir'] + '/latest'):
            os.mkdir(self.config['save_state_dir'] + '/latest')
        save_dic = {
            'epoch': save_epoch_val,
            'best_epoch': self._best_epoch,
            'train_examples': self._n_train_examples,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_f1': self._best_f1,
            'best_em': self._best_em,
            'dataloader_batch_state': self.train_loader.batch_state,
            'dataloader_state': self.train_loader.state,
            'dataloader_examples': self.train_loader.examples
        }
        torch.save(save_dic,
                   self.config['save_state_dir'] + '/latest/model.pth')

    def _run_epoch(self,
                   data_loader,
                   training=True,
                   verbose=10,
                   out_predictions=False,
                   save=True):
        start_time = time.time()
        while data_loader.batch_state < len(data_loader):
            input_batch = data_loader.get()
            res = self.model(input_batch, training)
            tr_loss = 0
            if training:
                loss = res['loss']
                if self.config['gradient_accumulation_steps'] > 1:
                    loss = loss / self.config['gradient_accumulation_steps']
                tr_loss = loss.mean().item()
            start_logits = res['start_logits']
            end_logits = res['end_logits']

            if training:
                self.model.update(loss, self.optimizer,
                                  data_loader.batch_state)
            paragraphs = [inp['tokens'] for inp in input_batch]
            answers = [inp['answer'] for inp in input_batch]
            # paragraph_id_list = [inp['paragraph_id'] for inp in input_batch]
            # turn_id_list = [inp['turn_id'] for inp in input_batch]
            # print("paragraph_id:{0},turn_id:{1}".format(paragraph_id_list[0],turn_id_list[0]))
            f1, em = self.model.evaluate(start_logits, end_logits, paragraphs,
                                         answers)

            self._update_metrics(tr_loss,
                                 f1,
                                 em,
                                 len(paragraphs),
                                 training=training)

            if training:
                self._n_train_examples += len(paragraphs)
            if (verbose > 0) and (data_loader.batch_state % verbose == 0):
                if save:
                    self.save(self._epoch - 1)
                mode = "train" if training else "dev"
                print(
                    self.report(data_loader.batch_state, tr_loss, f1 * 100,
                                em * 100, mode))
                print('used_time: {:0.2f}s'.format(time.time() - start_time))

    def _update_metrics(self, loss, f1, em, batch_size, training=True):
        if training:
            self._train_loss.update(loss)
            self._train_f1.update(f1 * 100, batch_size)
            self._train_em.update(em * 100, batch_size)
        else:
            self._dev_f1.update(f1 * 100, batch_size)
            self._dev_em.update(em * 100, batch_size)

    def _reset_metrics(self):
        self._train_loss.reset()
        self._train_f1.reset()
        self._train_em.reset()
        self._dev_f1.reset()
        self._dev_em.reset()

    def report(self, step, loss, f1, em, mode='train'):
        if mode == "train":
            format_str = "[train-{}] step: [{} / {}] | exs = {} | loss = {:0.4f} | f1 = {:0.2f} | em = {:0.2f}"
            return format_str.format(self._epoch, step, self._n_train_batches,
                                     self._n_train_examples, loss, f1, em)
        elif mode == "dev":
            return "[predict-{}] step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                self._epoch, step, self._n_dev_batches, f1, em)
        elif mode == "test":
            return "[test] | test_exs = {} | step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                self._n_test_examples, step, self._n_test_batches, f1, em)
        else:
            raise ValueError('mode = {} not supported.'.format(mode))

    def _stop_condition(self, epoch):
        """
		Checks have not exceeded max epochs and has not gone 10 epochs without improvement.
		"""
        no_improvement = epoch >= self._best_epoch + 10
        exceeded_max_epochs = epoch >= self.config['max_epochs']
        return False if exceeded_max_epochs or no_improvement else True

    def test(self):
        data_loader = self.dev_loader
        data_loader.batch_size = 1
        prediction_dict_list = []
        cnt = 1
        last_paragraph_id = -1
        last_turn_id = -1
        answer_filename = 'data/answers.json'
        timer1 = Timer()
        while data_loader.batch_state < len(data_loader):
            # if cnt>3:
            # 	break
            if cnt % 2000 == 0:
                print(timer1.remains(len(data_loader), cnt))
            input_batch = data_loader.get()
            prediction = self.gen_prediction(input_batch)
            turn_id = gen_turn_id(input_batch)
            paragraph_id = gen_paragraph_id(input_batch)
            prediction_dict = {
                "id": paragraph_id[0],
                "turn_id": turn_id[0],
                "answer": prediction[0]
            }

            is_exist, last_paragraph_id, last_turn_id = check_exist_status(
                paragraph_id, turn_id, last_paragraph_id, last_turn_id)
            if not is_exist:
                prediction_dict_list.append(prediction_dict)
                cnt += 1

        with open(answer_filename, 'w') as outfile:
            json.dump(prediction_dict_list, outfile)
        test_evaluator.test('data/coqa-dev-v1.0.json', answer_filename)
        print("generate {} answers".format(cnt - 1))

    def gen_prediction(self, input_batch):
        res = self.model(input_batch, False)
        start_logits = res['start_logits']
        end_logits = res['end_logits']
        paragraphs = [inp['tokens'] for inp in input_batch]
        predictions = self.model.gen_prediction(start_logits, end_logits,
                                                paragraphs)
        return predictions
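
Timer is referenced in several examples (remains, interval, finish, .total) but never shown. A minimal sketch matching those call sites, under the assumption that remains() is a simple linear ETA estimate, could be:

import time

class Timer(object):
    """Minimal sketch (assumption) of the Timer used in these handlers."""

    def __init__(self, name=''):
        self.name = name
        self.start = time.time()
        self._last = self.start
        self.total = 0.0

    def remains(self, total_steps, done_steps):
        # Linear extrapolation of elapsed time over the remaining steps.
        if done_steps <= 0:
            return 'unknown'
        elapsed = time.time() - self.start
        eta = elapsed / done_steps * (total_steps - done_steps)
        return '{:.0f}s'.format(eta)

    def interval(self, label=''):
        # Return the time since the previous interval() call.
        now = time.time()
        span = now - self._last
        self._last = now
        print('{} took {:.2f}s'.format(label, span))
        return span

    def finish(self):
        self.total = time.time() - self.start
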
Example #7
    def __init__(self, config):
        self.logger = ModelLogger(config, dirname=config['dir'], pretrained=config['pretrained'])
        self.dirname = self.logger.dirname
        cuda = config['cuda']
        cuda_id = config['cuda_id']
        if not cuda:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda' if cuda_id < 0 else 'cuda:%d' % cuda_id)
        print("preparing datasets...")
        print(config["pretrained"])

        datasets = None  # loaded from a pickle below rather than prepare_datasets(config)

        if config["pretrained"] == "rc_models_20":
            # baseline
            with open("data/coqa_baseline_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        elif config["pretrained"] == "rc_models_modified_20":
            # history only
            with open("data/coqa_history_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        elif config["pretrained"] == "rc_models_modified_full_20":
            # current question and answer
            with open("data/coqa_full_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        elif config["pretrained"] == "rc_models_modified_full_noA_20":
            # current question
            with open("data/coqa_full_noA_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        else:
            print("not a valid pretrained model")
            exit()


        train_set = None #datasets['train']
        dev_set = None #datasets['dev']
        test_set = datasets['dev']
        print("datasets prepared")

        # print(train_set.examples[:5])

        # Evaluation Metrics:
        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        if train_set:
            self.train_loader = DataLoader(train_set, batch_size=config['batch_size'],
                                           shuffle=config['shuffle'], collate_fn=lambda x: x, pin_memory=True)
            self._n_train_batches = len(train_set) // config['batch_size']
        else:
            self.train_loader = None

        if dev_set:
            self.dev_loader = DataLoader(dev_set, batch_size=config['batch_size'],
                                         shuffle=False, collate_fn=lambda x: x, pin_memory=True)
            self._n_dev_batches = len(dev_set) // config['batch_size']
        else:
            self.dev_loader = None

        if test_set:
            self.test_loader = DataLoader(test_set, batch_size=config['batch_size'], shuffle=False,
                                          collate_fn=lambda x: x, pin_memory=True)
            self._n_test_batches = len(test_set) // config['batch_size']
            self._n_test_examples = len(test_set)
        else:
            self.test_loader = None

        self._n_train_examples = 0
        self.model = Model(config, train_set)
        self.model.network = self.model.network.to(self.device)
        self.config = self.model.config
        self.is_test = False
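
The if/elif chain that maps a pretrained tag to a pickle file (here and in Example #8) can be collapsed into dictionary dispatch. A sketch using the same four files and tags from the examples:

import pickle

# Same tag-to-file mapping as the if/elif chain above.
DATA_FILES = {
    "rc_models_20": "data/coqa_baseline_data.pkl",                    # baseline
    "rc_models_modified_20": "data/coqa_history_data.pkl",            # history only
    "rc_models_modified_full_20": "data/coqa_full_data.pkl",          # current question and answer
    "rc_models_modified_full_noA_20": "data/coqa_full_noA_data.pkl",  # current question
}

def load_datasets(pretrained):
    try:
        path = DATA_FILES[pretrained]
    except KeyError:
        raise SystemExit("not a valid pretrained model: {}".format(pretrained))
    with open(path, "rb") as f_in:
        return pickle.load(f_in)
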
Example #8
class ModelHandler(object):
    """High level model_handler that trains/validates/tests the network,
    tracks and logs metrics.
    """

    def __init__(self, config):
        self.logger = ModelLogger(config, dirname=config['dir'], pretrained=config['pretrained'])
        self.dirname = self.logger.dirname
        cuda = config['cuda']
        cuda_id = config['cuda_id']
        if not cuda:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda' if cuda_id < 0 else 'cuda:%d' % cuda_id)
        print("preparing datasets...")
        print(config["pretrained"])

        datasets = None  # loaded from a pickle below rather than prepare_datasets(config)

        if config["pretrained"] == "rc_models_20":
            # baseline
            with open("data/coqa_baseline_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        elif config["pretrained"] == "rc_models_modified_20":
            # history only
            with open("data/coqa_history_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        elif config["pretrained"] == "rc_models_modified_full_20":
            # current question and answer
            with open("data/coqa_full_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        elif config["pretrained"] == "rc_models_modified_full_noA_20":
            # current question
            with open("data/coqa_full_noA_data.pkl", "rb") as f_in:
                datasets = pickle.load(f_in)

        else:
            print("not a valid pretrained model")
            exit()


        train_set = None #datasets['train']
        dev_set = None #datasets['dev']
        test_set = datasets['dev']
        print("datasets prepared")

        # print(train_set.examples[:5])

        # Evaluation Metrics:
        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        if train_set:
            self.train_loader = DataLoader(train_set, batch_size=config['batch_size'],
                                           shuffle=config['shuffle'], collate_fn=lambda x: x, pin_memory=True)
            self._n_train_batches = len(train_set) // config['batch_size']
        else:
            self.train_loader = None

        if dev_set:
            self.dev_loader = DataLoader(dev_set, batch_size=config['batch_size'],
                                         shuffle=False, collate_fn=lambda x: x, pin_memory=True)
            self._n_dev_batches = len(dev_set) // config['batch_size']
        else:
            self.dev_loader = None

        if test_set:
            self.test_loader = DataLoader(test_set, batch_size=config['batch_size'], shuffle=False,
                                          collate_fn=lambda x: x, pin_memory=True)
            self._n_test_batches = len(test_set) // config['batch_size']
            self._n_test_examples = len(test_set)
        else:
            self.test_loader = None

        self._n_train_examples = 0
        self.model = Model(config, train_set)
        self.model.network = self.model.network.to(self.device)
        self.config = self.model.config
        self.is_test = False

    def train(self):
        if self.train_loader is None or self.dev_loader is None:
            print("No training set or dev set specified -- skipped training.")
            return

        self.is_test = False
        timer = Timer("Train")
        self._epoch = self._best_epoch = 0

        if self.dev_loader is not None:
            print("\n>>> Dev Epoch: [{} / {}]".format(self._epoch, self.config['max_epochs']))
            self._run_epoch(self.dev_loader, training=False, verbose=self.config['verbose'])
            timer.interval("Validation Epoch {}".format(self._epoch))
            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(format_str.format(self._epoch, self._dev_f1.mean(), self._dev_em.mean()))

        self._best_f1 = self._dev_f1.mean()
        self._best_em = self._dev_em.mean()
        if self.config['save_params']:
            self.model.save(self.dirname)
        self._reset_metrics()

        while self._stop_condition(self._epoch):
            self._epoch += 1

            print("\n>>> Train Epoch: [{} / {}]".format(self._epoch, self.config['max_epochs']))
            self._run_epoch(self.train_loader, training=True, verbose=self.config['verbose'])
            train_epoch_time = timer.interval("Training Epoch {}".format(self._epoch))
            format_str = "Training Epoch {} -- Loss: {:0.4f}, F1: {:0.2f}, EM: {:0.2f} --"
            print(format_str.format(self._epoch, self._train_loss.mean(),
                  self._train_f1.mean(), self._train_em.mean()))

            print("\n>>> Dev Epoch: [{} / {}]".format(self._epoch, self.config['max_epochs']))
            self._run_epoch(self.dev_loader, training=False, verbose=self.config['verbose'])
            timer.interval("Validation Epoch {}".format(self._epoch))
            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(format_str.format(self._epoch, self._dev_f1.mean(), self._dev_em.mean()))

            if self._best_f1 <= self._dev_f1.mean():  # Can be one of loss, f1, or em.
                self._best_epoch = self._epoch
                self._best_f1 = self._dev_f1.mean()
                self._best_em = self._dev_em.mean()
                if self.config['save_params']:
                    self.model.save(self.dirname)
                print("!!! Updated: F1: {:0.2f}, EM: {:0.2f}".format(self._best_f1, self._best_em))

            self._reset_metrics()
            self.logger.log(self._train_loss.last, Constants._TRAIN_LOSS_EPOCH_LOG)
            self.logger.log(self._train_f1.last, Constants._TRAIN_F1_EPOCH_LOG)
            self.logger.log(self._train_em.last, Constants._TRAIN_EM_EPOCH_LOG)
            self.logger.log(self._dev_f1.last, Constants._DEV_F1_EPOCH_LOG)
            self.logger.log(self._dev_em.last, Constants._DEV_EM_EPOCH_LOG)
            self.logger.log(train_epoch_time, Constants._TRAIN_EPOCH_TIME_LOG)

        timer.finish()
        self.training_time = timer.total

        print("Finished Training: {}".format(self.dirname))
        print(self.summary())

    def test(self):
        if self.test_loader is None:
            print("No testing set specified -- skipped testing.")
            return

        self.is_test = True
        self._reset_metrics()
        timer = Timer("Test")
        output = self._run_epoch(self.test_loader, training=False, verbose=0,
                                 out_predictions=self.config['out_predictions'])

        for ex in output:
            _id = ex['id']
            ex['id'] = _id[0]
            ex['turn_id'] = _id[1]

        if self.config['out_predictions']:
            output_file = os.path.join(self.dirname, Constants._PREDICTION_FILE)
            with open(output_file, 'w') as outfile:
                json.dump(output, outfile, indent=4)

        test_f1 = self._dev_f1.mean()
        test_em = self._dev_em.mean()

        timer.finish()
        print(self.report(self._n_test_batches, None, test_f1, test_em, mode='test'))
        self.logger.log([test_f1, test_em], Constants._TEST_EVAL_LOG)
        print("Finished Testing: {}".format(self.dirname))

    def _run_epoch(self, data_loader, training=True, verbose=10, out_predictions=False):
        start_time = time.time()
        output = []
        for step, input_batch in enumerate(data_loader):
            input_batch = sanitize_input(input_batch, self.config, self.model.word_dict,
                                         self.model.feature_dict, training=training)
            x_batch = vectorize_input(input_batch, self.config, training=training, device=self.device)
            if not x_batch:
                continue  # When there are no target spans present in the batch
            # print("running train predictions")
            res = self.model.predict(x_batch, update=training, out_predictions=out_predictions)

            loss = res['loss']
            f1 = res['f1']
            em = res['em']
            # print("updating metrics")
            self._update_metrics(loss, f1, em, x_batch['batch_size'], training=training)

            if training:
                self._n_train_examples += x_batch['batch_size']

            if (verbose > 0) and (step % verbose == 0):
                mode = "train" if training else ("test" if self.is_test else "dev")
                print(self.report(step, loss, f1 * 100, em * 100, mode))
                print('used_time: {:0.2f}s'.format(time.time() - start_time))

            if out_predictions:
                # `ex_id` avoids shadowing the built-in id().
                for ex_id, prediction, span in zip(input_batch['id'],
                                                   res['predictions'],
                                                   res['spans']):
                    output.append({'id': ex_id,
                                   'answer': prediction,
                                   'span_start': span[0],
                                   'span_end': span[1]})
        return output

    def report(self, step, loss, f1, em, mode='train'):
        if mode == "train":
            format_str = "[train-{}] step: [{} / {}] | exs = {} | loss = {:0.4f} | f1 = {:0.2f} | em = {:0.2f}"
            return format_str.format(self._epoch, step, self._n_train_batches, self._n_train_examples, loss, f1, em)
        elif mode == "dev":
            return "[predict-{}] step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                    self._epoch, step, self._n_dev_batches, f1, em)
        elif mode == "test":
            return "[test] | test_exs = {} | step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                    self._n_test_examples, step, self._n_test_batches, f1, em)
        else:
            raise ValueError('mode = {} not supported.'.format(mode))

    def summary(self):
        start = " <<<<<<<<<<<<<<<< MODEL SUMMARY >>>>>>>>>>>>>>>> "
        info = "Best epoch = {}\nDev F1 = {:0.2f}\nDev EM = {:0.2f}".format(
            self._best_epoch, self._best_f1, self._best_em)
        end = " <<<<<<<<<<<<<<<< MODEL SUMMARY >>>>>>>>>>>>>>>> "
        return "\n".join([start, info, end])

    def _update_metrics(self, loss, f1, em, batch_size, training=True):
        if training:
            self._train_loss.update(loss)
            self._train_f1.update(f1 * 100, batch_size)
            self._train_em.update(em * 100, batch_size)
        else:
            self._dev_f1.update(f1 * 100, batch_size)
            self._dev_em.update(em * 100, batch_size)

    def _reset_metrics(self):
        self._train_loss.reset()
        self._train_f1.reset()
        self._train_em.reset()
        self._dev_f1.reset()
        self._dev_em.reset()

    def _stop_condition(self, epoch):
        """
        Checks have not exceeded max epochs and has not gone 10 epochs without improvement.
        """
        no_improvement = epoch >= self._best_epoch + 10
        exceeded_max_epochs = epoch >= self.config['max_epochs']
        return False if exceeded_max_epochs or no_improvement else True
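
For context, a hypothetical driver for handlers like the one above; the config keys mirror the ones __init__, train, and test read, and every value below is a placeholder, not a recommended setting.

# Hypothetical usage; keys mirror what ModelHandler reads above.
config = {
    'dir': 'out/',
    'pretrained': 'rc_models_20',
    'cuda': True,
    'cuda_id': 0,
    'batch_size': 32,
    'shuffle': True,
    'max_epochs': 30,
    'verbose': 100,
    'save_params': True,
    'out_predictions': True,
}

handler = ModelHandler(config)
handler.train()
handler.test()
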
Example #9
class ModelHandler(object):
    """High level model_handler that trains/validates/tests the network,
    tracks and logs metrics.
    """
    def __init__(self, config):
        self.logger = ModelLogger(config,
                                  dirname=config['dir'],
                                  pretrained=config['pretrained'])
        self.dirname = self.logger.dirname
        cuda = config['cuda']
        cuda_id = config['cuda_id']
        if not cuda:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda' if cuda_id < 0 else 'cuda:%d' %
                                       cuda_id)

        datasets = prepare_datasets(config)
        train_set = datasets['train']
        dev_set = datasets['dev']
        test_set = datasets['test']

        # Evaluation Metrics:
        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        if train_set:
            self.train_loader = DataLoader(train_set,
                                           batch_size=config['batch_size'],
                                           shuffle=config['shuffle'],
                                           collate_fn=lambda x: x,
                                           pin_memory=True)
            self._n_train_batches = len(train_set) // config['batch_size']
        else:
            self.train_loader = None

        if dev_set:
            self.dev_loader = DataLoader(dev_set,
                                         batch_size=config['batch_size'],
                                         shuffle=False,
                                         collate_fn=lambda x: x,
                                         pin_memory=True)
            self._n_dev_batches = len(dev_set) // config['batch_size']
        else:
            self.dev_loader = None

        if test_set:
            self.test_loader = DataLoader(test_set,
                                          batch_size=config['batch_size'],
                                          shuffle=False,
                                          collate_fn=lambda x: x,
                                          pin_memory=True)
            self._n_test_batches = len(test_set) // config['batch_size']
            self._n_test_examples = len(test_set)
        else:
            self.test_loader = None

        self._n_train_examples = 0
        self.model = Model(config, train_set)
        self.model.network = self.model.network.to(self.device)
        self.config = self.model.config
        self.is_test = False

    def train(self):
        if self.train_loader is None or self.dev_loader is None:
            print("No training set or dev set specified -- skipped training.")
            return

        self.is_test = False
        timer = Timer("Train")
        self._epoch = self._best_epoch = 0

        if self.dev_loader is not None:
            print("\n>>> Dev Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            self._run_epoch(self.dev_loader,
                            training=False,
                            verbose=self.config['verbose'])
            timer.interval("Validation Epoch {}".format(self._epoch))
            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._dev_f1.mean(),
                                  self._dev_em.mean()))

        self._best_f1 = self._dev_f1.mean()
        self._best_em = self._dev_em.mean()
        if self.config['save_params']:
            self.model.save(self.dirname)
        self._reset_metrics()

        while self._stop_condition(self._epoch):
            self._epoch += 1

            print("\n>>> Train Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            self._run_epoch(self.train_loader,
                            training=True,
                            verbose=self.config['verbose'])
            train_epoch_time = timer.interval("Training Epoch {}".format(
                self._epoch))
            format_str = "Training Epoch {} -- Loss: {:0.4f}, F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._train_loss.mean(),
                                  self._train_f1.mean(),
                                  self._train_em.mean()))

            print("\n>>> Dev Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            self._run_epoch(self.dev_loader,
                            training=False,
                            verbose=self.config['verbose'])
            timer.interval("Validation Epoch {}".format(self._epoch))
            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._dev_f1.mean(),
                                  self._dev_em.mean()))

            # The selection criterion could be one of loss, F1, or EM; F1 is used here.
            if self._best_f1 <= self._dev_f1.mean():
                self._best_epoch = self._epoch
                self._best_f1 = self._dev_f1.mean()
                self._best_em = self._dev_em.mean()
                if self.config['save_params']:
                    self.model.save(self.dirname)
                print("!!! Updated: F1: {:0.2f}, EM: {:0.2f}".format(
                    self._best_f1, self._best_em))

            self._reset_metrics()
            self.logger.log(self._train_loss.last,
                            Constants._TRAIN_LOSS_EPOCH_LOG)
            self.logger.log(self._train_f1.last, Constants._TRAIN_F1_EPOCH_LOG)
            self.logger.log(self._train_em.last, Constants._TRAIN_EM_EPOCH_LOG)
            self.logger.log(self._dev_f1.last, Constants._DEV_F1_EPOCH_LOG)
            self.logger.log(self._dev_em.last, Constants._DEV_EM_EPOCH_LOG)
            self.logger.log(train_epoch_time, Constants._TRAIN_EPOCH_TIME_LOG)

        timer.finish()
        self.training_time = timer.total

        print("Finished Training: {}".format(self.dirname))
        print(self.summary())

    def test(self):
        if self.test_loader is None:
            print("No testing set specified -- skipped testing.")
            return

        self.is_test = True
        self._reset_metrics()
        timer = Timer("Test")
        output = self._run_epoch(
            self.test_loader,
            training=False,
            verbose=0,
            out_predictions=self.config['out_predictions'],
            out_attentions=self.config['save_attn_weights'])

        if self.config['dialog_batched']:
            # Slightly different id format
            _id = None
            turn = 0
            for ex in output:
                if ex['id'] != _id:
                    _id = ex['id']
                    turn = 0
                ex['id'] = _id
                ex['turn_id'] = turn
                turn += 1
        else:
            for ex in output:
                _id = ex['id']
                ex['id'] = _id[0]
                ex['turn_id'] = _id[1]

        if self.config['out_predictions']:
            output_file = os.path.join(self.dirname,
                                       Constants._PREDICTION_FILE)
            with open(output_file, 'w') as outfile:
                json.dump(output, outfile, indent=4)
                if self.config['out_predictions_csv']:
                    import pandas as pd
                    for o in output:
                        (o['gold_answer_1'], o['gold_answer_2'],
                         o['gold_answer_3'],
                         o['gold_answer_4']) = o['gold_answers']
                    output_csv = pd.DataFrame(output)
                    output_csv = output_csv[[
                        'id', 'turn_id', 'span_start', 'span_end', 'answer',
                        'gold_answer_1', 'gold_answer_2', 'gold_answer_3',
                        'gold_answer_4', 'f1', 'em'
                    ]]
                    output_csv.to_csv(output_file.replace('.json', '.csv'),
                                      index=False)

        test_f1 = self._dev_f1.mean()
        test_em = self._dev_em.mean()

        timer.finish()
        print(
            self.report(self._n_test_batches,
                        None,
                        test_f1,
                        test_em,
                        mode='test'))
        self.logger.log([test_f1, test_em], Constants._TEST_EVAL_LOG)
        print("Finished Testing: {}".format(self.dirname))

    def _run_epoch(self,
                   data_loader,
                   training=True,
                   verbose=10,
                   out_predictions=False,
                   out_attentions=None):
        start_time = time.time()
        output = []
        for step, input_batch in enumerate(data_loader):
            if self.config['dialog_batched']:
                x_batches = []
                for ib in input_batch:
                    ib_sanitized = sanitize_input_dialog_batched(
                        ib,
                        self.config,
                        self.model.word_dict,
                        self.model.feature_dict,
                        training=training)
                    x_batch = vectorize_input_dialog_batched(
                        ib_sanitized,
                        self.config,
                        training=training,
                        device=self.device)
                    if not x_batch:
                        continue  # When there are no target spans present in the batch
                    x_batches.append(x_batch)
            else:
                input_batch = sanitize_input(input_batch,
                                             self.config,
                                             self.model.word_dict,
                                             self.model.feature_dict,
                                             training=training)
                x_batch = vectorize_input(input_batch,
                                          self.config,
                                          training=training,
                                          device=self.device)
                if not x_batch:
                    continue  # When there are no target spans present in the batch
                x_batches = [x_batch]  # Singleton list.

            res = self.model.predict(x_batches,
                                     update=training,
                                     out_predictions=out_predictions,
                                     out_attentions=out_attentions)

            loss = res['loss']
            f1 = res['f1']
            em = res['em']
            total_ex = sum(xb['batch_size'] for xb in x_batches)
            self._update_metrics(loss, f1, em, total_ex, training=training)

            if training:
                self._n_train_examples += total_ex

            if (verbose > 0) and (step % verbose == 0):
                mode = "train" if training else (
                    "test" if self.is_test else "dev")
                print(self.report(step, loss, f1 * 100, em * 100, mode))
                print('used_time: {:0.2f}s'.format(time.time() - start_time))

            if out_predictions:
                # `ex_id` avoids shadowing the built-in id().
                for ex_id, prediction, span, f1, em, ans in zip(
                        res['ids'], res['predictions'], res['spans'],
                        res['f1s'], res['ems'], res['answers']):
                    output.append({
                        'id': ex_id,
                        'answer': prediction,
                        'span_start': span[0],
                        'span_end': span[1],
                        'f1': f1,
                        'em': em,
                        'gold_answers': ans
                    })
        return output

    def report(self, step, loss, f1, em, mode='train'):
        if mode == "train":
            format_str = "[train-{}] step: [{} / {}] | exs = {} | loss = {:0.4f} | f1 = {:0.2f} | em = {:0.2f}"
            return format_str.format(self._epoch, step, self._n_train_batches,
                                     self._n_train_examples, loss, f1, em)
        elif mode == "dev":
            return "[predict-{}] step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                self._epoch, step, self._n_dev_batches, f1, em)
        elif mode == "test":
            return "[test] | test_exs = {} | step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                self._n_test_examples, step, self._n_test_batches, f1, em)
        else:
            raise ValueError('mode = {} not supported.'.format(mode))

    def summary(self):
        start = " <<<<<<<<<<<<<<<< MODEL SUMMARY >>>>>>>>>>>>>>>> "
        info = "Best epoch = {}\nDev F1 = {:0.2f}\nDev EM = {:0.2f}".format(
            self._best_epoch, self._best_f1, self._best_em)
        end = " <<<<<<<<<<<<<<<< MODEL SUMMARY >>>>>>>>>>>>>>>> "
        return "\n".join([start, info, end])

    def _update_metrics(self, loss, f1, em, batch_size, training=True):
        if training:
            self._train_loss.update(loss, batch_size)
            self._train_f1.update(f1 * 100, batch_size)
            self._train_em.update(em * 100, batch_size)
        else:
            self._dev_f1.update(f1 * 100, batch_size)
            self._dev_em.update(em * 100, batch_size)

    def _reset_metrics(self):
        self._train_loss.reset()
        self._train_f1.reset()
        self._train_em.reset()
        self._dev_f1.reset()
        self._dev_em.reset()

    def _stop_condition(self, epoch):
        """
        Checks have not exceeded max epochs and has not gone 10 epochs without improvement.
        """
        no_improvement = epoch >= self._best_epoch + 10
        exceeded_max_epochs = epoch >= self.config['max_epochs']
        return False if exceeded_max_epochs or no_improvement else True
Example #10
    def train(self):
        """ Train the network.
        :return: None
        """
        train_iter = iter(self.train_loader)
        valid_iter = iter(self.valid_loader)

        global_step = warmup_step = 0
        train_step = valid_step = 0
        train_epoch = valid_epoch = 0
        curr_train_epoch = 0

        # Keep track of losses.
        losses = AverageMeter()
        val_losses = []
        best_val_loss = float("inf")
        best_val_epoch = 0
        try:
            self.evaluate_fnc(train_epoch)
        except Exception as err:
            raise Exception("Please check your evaluation function {}.".format(
                str(self.evaluate_fnc))) from err

        while train_epoch < self.warmup_epochs:
            # Reset the data augmentation parameters.
            self.train_loader.dataset.reset_hyper_params()

            perturbed_h_tensor = self.h_container.get_perturbed_hyper(self.train_loader.batch_size)

            # Set the data augmentation hyperparameters.
            self.train_loader.dataset.set_h_container(self.h_container, perturbed_h_tensor)
            inputs, augmented_inputs, labels, train_iter, train_epoch = \
                next_batch(train_iter, self.train_loader, train_epoch, self.device)

            if curr_train_epoch != train_epoch:
                # When train_epoch changes, evaluate validation & test losses.
                val_loss = self.evaluate_fnc(train_epoch, losses.avg)
                val_losses.append(val_loss)

                losses.reset()
                curr_train_epoch = train_epoch

                self.lr_step(val_loss)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_epoch = curr_train_epoch
                wandb.log({
                        "best_val_loss": best_val_loss,
                        "best_val_epoch": best_val_epoch})

            # Taking care of the last batch.
            if inputs.size(0) != self.train_loader.batch_size:
                perturbed_h_tensor = perturbed_h_tensor[:inputs.size(0), :]

            _, loss = self.step_optimizer.step(inputs, labels, perturbed_h_tensor=perturbed_h_tensor,
                                               augmented_inputs=augmented_inputs, tune_hyper=False)
            losses.update(loss.item(), inputs.size(0))

            if warmup_step % self.log_interval == 0 and global_step > 0:
                print("Global Step: {} Train Epoch: {} Warmup step: {} Loss: {:.3f}".format(
                    global_step, train_epoch, warmup_step, loss))

            warmup_step += 1
            global_step += 1

        print("Warm-up finished.")
        if self.patience is None:
            self.patience = self.total_epochs

        patience_elapsed = 0
        while patience_elapsed < self.patience and train_epoch < self.total_epochs:
            for _ in range(self.train_steps):
                # Perform training steps:
                self.train_loader.dataset.reset_hyper_params()
                perturbed_h_tensor = self.h_container.get_perturbed_hyper(self.train_loader.batch_size)
                self.train_loader.dataset.set_h_container(self.h_container, perturbed_h_tensor)
                inputs, augmented_inputs, labels, train_iter, train_epoch = \
                    next_batch(train_iter, self.train_loader, train_epoch, self.device)

                if curr_train_epoch != train_epoch:
                    val_loss = self.evaluate_fnc(train_epoch, losses.avg)
                    val_losses.append(val_loss)
                    losses.reset()
                    curr_train_epoch = train_epoch

                    self.lr_step(val_loss)

                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        best_val_epoch = curr_train_epoch
                        patience_elapsed = 0
                    else:
                        patience_elapsed += 1
                    wandb.log(
                        {"best_val_loss": best_val_loss, "best_val_epoch": best_val_epoch}
                    )

                # Again, take care of the last batch.
                if inputs.size(0) != self.train_loader.batch_size:
                    perturbed_h_tensor = perturbed_h_tensor[:inputs.size(0), :]

                _, loss = self.step_optimizer.step(inputs, labels, perturbed_h_tensor=perturbed_h_tensor,
                                                   augmented_inputs=augmented_inputs, tune_hyper=False)
                losses.update(loss.item(), inputs.size(0))

                if train_step % self.log_interval == 0 and global_step > 0:
                    print(
                        "Train - Global Step: {} Train Epoch: {} Train step:{} "
                        "Loss: {:.3f}".format(
                            global_step, train_epoch, train_step, loss))
                train_step += 1
                global_step += 1

            for _ in range(self.valid_steps):
                inputs, _, labels, valid_iter, valid_epoch = \
                    next_batch(valid_iter, self.valid_loader, valid_epoch, self.device)
                perturbed_h_tensor = self.h_container.get_perturbed_hyper(inputs.size(0))

                _, loss = self.step_optimizer.step(inputs, labels, perturbed_h_tensor=perturbed_h_tensor,
                                                   augmented_inputs=None, tune_hyper=True)

                if valid_step % self.log_interval == 0 and global_step > 0:
                    print(
                        "Valid - Global Step: {} Valid Epoch: {} Valid step:{} "
                        "Loss: {:.3f}".format(global_step, valid_epoch, valid_step, loss))

                wandb.log(self.h_container.generate_summary())
                valid_step += 1
                global_step += 1
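
The loop above depends on a next_batch helper that transparently wraps around epoch boundaries. A sketch consistent with how the loop unpacks its return value (the batch field names are assumptions):

def next_batch(data_iter, loader, epoch, device):
    """Fetch the next batch, rebuilding the iterator at epoch boundaries.

    Returns (inputs, augmented_inputs, labels, data_iter, epoch); the
    epoch counter advances whenever the underlying iterator is exhausted.
    """
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(loader)
        batch = next(data_iter)
        epoch += 1
    inputs, augmented_inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    if augmented_inputs is not None:
        augmented_inputs = augmented_inputs.to(device)
    return inputs, augmented_inputs, labels, data_iter, epoch
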
Example #11
class ModelHandler():
    def __init__(self, config):
        self.config = config
        tokenizer_model = MODELS[config['model_name']]
        self.train_loader, self.dev_loader, tokenizer = prepare_datasets(
            config, tokenizer_model)
        self._n_dev_batches = len(
            self.dev_loader.dataset) // config['batch_size']
        self._n_train_batches = len(
            self.train_loader.dataset) // config['batch_size']
        if config['cuda']:
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        self.model = Model(config, MODELS[config['model_name']], self.device,
                           tokenizer).to(self.device)
        t_total = len(
            self.train_loader
        ) // config['gradient_accumulation_steps'] * config['max_epochs']
        self.optimizer = AdamW(self.model.parameters(),
                               lr=config['lr'],
                               eps=config['adam_epsilon'])
        self.optimizer.zero_grad()
        self._n_train_examples = 0
        self._epoch = self._best_epoch = 0
        self._best_f1 = 0
        self._best_em = 0
        self.restored = False
        if config['pretrained_dir'] is not None:
            self.restore()

    def train(self):
        if not self.restored:
            print("\n>>> Dev Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            self._run_epoch(self.dev_loader,
                            training=False,
                            verbose=self.config['verbose'],
                            save=False)

            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._dev_f1.mean(),
                                  self._dev_em.mean()))
            self._best_f1 = self._dev_f1.mean()
            self._best_em = self._dev_em.mean()
        while self._stop_condition(self._epoch):
            self._epoch += 1
            print("\n>>> Train Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            if not self.restored:
                self.train_loader.prepare()
            self.restored = False

            self._run_epoch(self.train_loader,
                            training=True,
                            verbose=self.config['verbose'])
            format_str = "Training Epoch {} -- Loss: {:0.4f}, F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._train_loss.mean(),
                                  self._train_f1.mean(),
                                  self._train_em.mean()))
            print("\n>>> Dev Epoch: [{} / {}]".format(
                self._epoch, self.config['max_epochs']))
            self.dev_loader.prepare()
            self._run_epoch(self.dev_loader,
                            training=False,
                            verbose=self.config['verbose'],
                            save=False)
            format_str = "Validation Epoch {} -- F1: {:0.2f}, EM: {:0.2f} --"
            print(
                format_str.format(self._epoch, self._dev_f1.mean(),
                                  self._dev_em.mean()))

            if self._best_f1 <= self._dev_f1.mean():
                self._best_epoch = self._epoch
                self._best_f1 = self._dev_f1.mean()
                self._best_em = self._dev_em.mean()
                print("!!! Updated: F1: {:0.2f}, EM: {:0.2f}".format(
                    self._best_f1, self._best_em))
            self._reset_metrics()
            self.save(self._epoch)

    def restore(self):
        if not os.path.exists(self.config['pretrained_dir']):
            print("dir doesn't exist, cannot restore")
            return
        restored_params = torch.load(self.config['pretrained_dir'] +
                                     '/latest/model.pth')
        self.model.load_state_dict(restored_params['model'])
        self.optimizer.load_state_dict(restored_params['optimizer'])
        self._epoch = restored_params['epoch']
        self._best_epoch = restored_params['best_epoch']
        self._n_train_examples = restored_params['train_examples']
        self._best_f1 = restored_params['best_f1']
        self._best_em = restored_params['best_em']
        examples = restored_params['dataloader_examples']
        batch_state = restored_params['dataloader_batch_state']
        state = restored_params['dataloader_state']
        self.train_loader.restore(examples, state, batch_state)

        self.restored = True

    def save(self, save_epoch_val):
        if not os.path.exists(self.config['save_state_dir']):
            os.mkdir(self.config['save_state_dir'])

        if self._best_epoch == self._epoch:
            if not os.path.exists(self.config['save_state_dir'] + '/best'):
                os.mkdir(self.config['save_state_dir'] + '/best')
            save_dic = {
                'epoch': self._epoch,
                'best_epoch': self._best_epoch,
                'train_examples': self._n_train_examples,
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_f1': self._best_f1,
                'best_em': self._best_em
            }
            torch.save(save_dic,
                       self.config['save_state_dir'] + '/best/model.pth')
        if not os.path.exists(self.config['save_state_dir'] + '/latest'):
            os.mkdir(self.config['save_state_dir'] + '/latest')
        save_dic = {
            'epoch': save_epoch_val,
            'best_epoch': self._best_epoch,
            'train_examples': self._n_train_examples,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_f1': self._best_f1,
            'best_em': self._best_em,
            'dataloader_batch_state': self.train_loader.batch_state,
            'dataloader_state': self.train_loader.state,
            'dataloader_examples': self.train_loader.examples
        }
        torch.save(save_dic,
                   self.config['save_state_dir'] + '/latest/model.pth')
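
    # Checkpoint layout written above: '<save_state_dir>/best/model.pth' keeps
    # the best-epoch weights and optimizer state, while
    # '<save_state_dir>/latest/model.pth' additionally stores the train
    # loader's examples/state/batch_state so restore() can resume mid-epoch.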

    def _run_epoch(self,
                   data_loader,
                   training=True,
                   verbose=10,
                   out_predictions=False,
                   save=True):
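        # Note: 'verbose' is a report interval in batches, not a flag
        # (0 disables the periodic reports below).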
        start_time = time.time()
        while data_loader.batch_state < len(data_loader):
            input_batch = data_loader.get()
            res = self.model(input_batch, training)
            tr_loss = 0
            if training:
                loss = res['loss']
                if self.config['gradient_accumulation_steps'] > 1:
                    loss = loss / self.config['gradient_accumulation_steps']
                tr_loss = loss.mean().item()
            start_logits = res['start_logits']
            end_logits = res['end_logits']

            if training:
                self.model.update(loss, self.optimizer,
                                  data_loader.batch_state)
            paragraphs = [inp['tokens'] for inp in input_batch]
            answers = [inp['answer'] for inp in input_batch]
            f1, em = self.model.evaluate(start_logits, end_logits, paragraphs,
                                         answers)

            self._update_metrics(tr_loss,
                                 f1,
                                 em,
                                 len(paragraphs),
                                 training=training)

            if training:
                self._n_train_examples += len(paragraphs)
            if (verbose > 0) and (data_loader.batch_state % verbose == 0):
                if save:
                    self.save(self._epoch - 1)
                mode = "train" if training else "dev"
                print(
                    self.report(data_loader.batch_state, tr_loss, f1 * 100,
                                em * 100, mode))
                print('used_time: {:0.2f}s'.format(time.time() - start_time))

    def _update_metrics(self, loss, f1, em, batch_size, training=True):
        if training:
            self._train_loss.update(loss)
            self._train_f1.update(f1 * 100, batch_size)
            self._train_em.update(em * 100, batch_size)
        else:
            self._dev_f1.update(f1 * 100, batch_size)
            self._dev_em.update(em * 100, batch_size)

    def _reset_metrics(self):
        self._train_loss.reset()
        self._train_f1.reset()
        self._train_em.reset()
        self._dev_f1.reset()
        self._dev_em.reset()

    def report(self, step, loss, f1, em, mode='train'):
        if mode == "train":
            format_str = "[train-{}] step: [{} / {}] | exs = {} | loss = {:0.4f} | f1 = {:0.2f} | em = {:0.2f}"
            return format_str.format(self._epoch, step, self._n_train_batches,
                                     self._n_train_examples, loss, f1, em)
        elif mode == "dev":
            return "[predict-{}] step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                self._epoch, step, self._n_dev_batches, f1, em)
        elif mode == "test":
            return "[test] | test_exs = {} | step: [{} / {}] | f1 = {:0.2f} | em = {:0.2f}".format(
                self._n_test_examples, step, self._n_test_batches, f1, em)
        else:
            raise ValueError('mode = {} not supported.'.format(mode))

    def _stop_condition(self, epoch):
        """
        Returns True while training should continue: the max number of
        epochs has not been exceeded and fewer than 10 epochs have passed
        without improvement.
        """
        no_improvement = epoch >= self._best_epoch + 10
        exceeded_max_epochs = epoch >= self.config['max_epochs']
        return not (exceeded_max_epochs or no_improvement)
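
# ModelHandler (and the examples around it) track running statistics with an
# AverageMeter that is not defined in this excerpt. A minimal sketch under the
# API assumed by the calls above -- update(val, n=1), mean(), reset() -- with
# average() aliased for the usage in Example #13 below:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, val, n=1):
        # Accumulate a value weighted by the number of samples it covers.
        self.total += val * n
        self.count += n

    def mean(self):
        # Guard against division by zero before the first update().
        return self.total / max(self.count, 1)

    average = mean
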
Example #12
0
    def __init__(self, config):
        self.logger = ModelLogger(config,
                                  dirname=config['dir'],
                                  pretrained=config['pretrained'])
        self.dirname = self.logger.dirname
        #self.dirname=config["pretrained_model"]
        cuda = config['cuda']
        cuda_id = config['cuda_id']
        if not cuda:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda' if cuda_id < 0 else 'cuda:%d' %
                                       cuda_id)

        # Load the datasets
        datasets = prepare_datasets(config)
        train_set = datasets['train']
        dev_set = datasets['dev']
        test_set = datasets['test']

        # Evaluation Metrics:
        self._train_loss = AverageMeter()
        self._train_f1 = AverageMeter()
        self._train_em = AverageMeter()
        self._dev_f1 = AverageMeter()
        self._dev_em = AverageMeter()

        # Build the data loaders
        if train_set:
            self.train_loader = DataLoader(train_set,
                                           batch_size=config['batch_size'],
                                           shuffle=config['shuffle'],
                                           collate_fn=lambda x: x,
                                           pin_memory=True)
            self._n_train_batches = len(train_set) // config['batch_size']
        else:
            self.train_loader = None

        if dev_set:
            self.dev_loader = DataLoader(dev_set,
                                         batch_size=config['batch_size'],
                                         shuffle=False,
                                         collate_fn=lambda x: x,
                                         pin_memory=True)
            self._n_dev_batches = len(dev_set) // config['batch_size']
        else:
            self.dev_loader = None

        if test_set:
            self.test_loader = DataLoader(test_set,
                                          batch_size=config['batch_size'],
                                          shuffle=False,
                                          collate_fn=lambda x: x,
                                          pin_memory=True)
            self._n_test_batches = len(test_set) // config['batch_size']
            self._n_test_examples = len(test_set)
        else:
            self.test_loader = None

        # Prepare the model
        self._n_train_examples = 0
        self.model = Model(config, train_set)
        self.model.network = self.model.network.to(self.device)
        self.config = self.model.config
        self.is_test = False
        self.textlogger = get_logger("log.txt")
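
# Example #12 ends by creating a text logger with get_logger("log.txt"), a
# helper that is not shown here. A plausible minimal sketch (an assumption,
# not the original implementation):
import logging

def get_logger(filename):
    logger = logging.getLogger(filename)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # Attach the file handler only once, even across repeated calls.
        handler = logging.FileHandler(filename)
        handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
        logger.addHandler(handler)
    return logger
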
Example #13
0
    def validation(self, epoch):
        self.model.eval()
        for e in self.evaluator:
            e.reset()

        confidence_meter = [AverageMeter() for _ in range(self.args.C)]

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                outputs = self.model(image)
            loss = []
            for classifier_i in range(self.args.C):
                loss.append(self.criterion(outputs[classifier_i], target))

            loss = sum(loss) / (self.args.C)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            target_show = target

            prediction = []
            # Add batch sample into evaluator
            for classifier_i in range(self.args.C):
                pred = torch.argmax(outputs[classifier_i], axis=1)
                prediction.append(pred)
                self.evaluator[classifier_i].add_batch(
                    target, prediction[classifier_i])
                confidence = normalized_shannon_entropy(outputs[classifier_i])
                confidence_meter[classifier_i].update(confidence)

            # Log images for one batch per validation pass; the chosen batch
            # index advances every 100 epochs.
            if epoch // 100 == i:
                global_step = epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target_show, outputs[-1],
                                             global_step)

        mIoU = []
        mean_confidence = []
        for classifier_i, e in enumerate(self.evaluator):
            mIoU.append(e.Mean_Intersection_over_Union())
            self.writer.add_scalar(
                'val/classifier_' + str(classifier_i) + '/mIoU',
                mIoU[classifier_i], epoch)
            mean_confidence.append(confidence_meter[classifier_i].average())
            self.writer.add_scalar(
                'val/classifier_' + str(classifier_i) + '/confidence',
                mean_confidence[classifier_i], epoch)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.test_batch_size + image.data.shape[0]))
        # Report per-classifier mIoU and confidence for any number of
        # classifiers.
        print(", ".join("classifier_{}_mIoU: {}".format(i + 1, v)
                        for i, v in enumerate(mIoU)))
        print(", ".join("classifier_{}_confidence: {}".format(i + 1, v)
                        for i, v in enumerate(mean_confidence)))
        print('Loss: %.3f' % test_loss)

        new_pred = sum(mIoU) / self.args.C
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
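
# Example #13 scores per-classifier confidence with
# normalized_shannon_entropy, which is not defined in this excerpt. A sketch
# of one common formulation (an assumption, not the original code): Shannon
# entropy of the softmax distribution over classes, divided by its maximum
# log(C) so the score lies in [0, 1], then averaged over the batch.
import math

import torch
import torch.nn.functional as F

def normalized_shannon_entropy(logits, eps=1e-12):
    probs = F.softmax(logits, dim=1)
    entropy = -(probs * torch.log(probs + eps)).sum(dim=1)
    # Normalize by the maximum possible entropy, log(num_classes).
    return (entropy / math.log(logits.size(1))).mean().item()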