示例#1
0
    def validate(self, model: torch.nn.Module, lossfunction: _Loss, iter: Iterator, ema=None, log_results=False) -> \
            Tuple[float, float, float]:
        """Evaluate ``model`` on a validation iterator and score it with the SQuAD v1.1 metric.

        Runs the model over every batch of ``iter`` without tracking gradients, collects
        per-row losses and decoded answer spans, writes the predictions to a host-specific
        JSON file under ``.data/squad/`` and scores them against
        ``.data/squad/dev-v1.1.json`` with ``evaluate``.

        Args:
            model: Module returning ``(logprobs_S, logprobs_E)`` for a batch and exposing
                ``decode(logprobs_S, logprobs_E) -> (best_span_probs, candidates)``.
            lossfunction: Loss applied separately to start and end predictions. Its output
                is converted with ``.tolist()`` and extended into a per-row list, so it
                presumably uses ``reduction='none'`` — TODO confirm at call sites.
            iter: Iterable of batches exposing ``id``, ``a_start`` and ``a_end``.
            ema: Optional EMA state. When given, the averaged weights are loaded for the
                evaluation and the original weights are restored afterwards, also on error.
            log_results: When True, passes predictions and their probabilities to
                ``self.log_results``.

        Returns:
            ``(loss, exact_match, f1)``. ``loss`` is the mean, over distinct ids, of the
            lowest per-row loss recorded for each id.
        """
        model.eval()
        if ema is not None:
            backup_params = EMA.ema_backup_and_loadavg(ema, model)

        try:
            ids = []
            lossvalues = []
            spans = []
            span_probs = []
            # Inference only: skipping autograd bookkeeping saves memory and time.
            with torch.no_grad():
                for batch in iter:
                    ids += batch.id
                    logprobs_S, logprobs_E = model(batch)
                    loss_s = lossfunction(logprobs_S, batch.a_start)
                    loss_e = lossfunction(logprobs_E, batch.a_end)
                    lossvalues += (loss_s + loss_e).tolist()

                    best_span_probs, candidates = model.decode(logprobs_S, logprobs_E)
                    span_probs += best_span_probs.tolist()
                    spans += self.get_spans(batch, candidates)

            # The same id may appear in several rows. Keep the lowest loss per id, and the
            # last span / span probability seen for it.
            results = dict()
            lossdict = defaultdict(lambda: math.inf)
            probs = defaultdict(lambda: 0)
            for example_id, value, span, span_prob in zip(ids, lossvalues, spans, span_probs):
                if lossdict[example_id] > value:
                    lossdict[example_id] = value
                results[example_id] = span
                probs[example_id] = span_prob

            if log_results:
                self.log_results(results, probs)

            loss = sum(lossdict.values()) / len(lossdict)

            prediction_file = f".data/squad/dev_results_{socket.gethostname()}.json"
            with open(prediction_file, "w", encoding="utf-8") as out_f:
                json.dump(results, out_f)

            dataset_file = ".data/squad/dev-v1.1.json"
            expected_version = '1.1'
            with open(dataset_file, encoding="utf-8") as dataset_f:
                dataset_json = json.load(dataset_f)
            if dataset_json['version'] != expected_version:
                # logging calls do not accept print()'s ``file=`` keyword; passing it raised TypeError.
                logging.warning('Evaluation expects v-' + expected_version +
                                ', but got dataset with v-' + dataset_json['version'])
            dataset = dataset_json['data']

            # Score the predictions exactly as they were written to disk (JSON round-trip).
            with open(prediction_file, encoding="utf-8") as pred_f:
                predictions = json.load(pred_f)
            result = evaluate(dataset, predictions)
            logging.info(json.dumps(result))

            return loss, result["exact_match"], result["f1"]
        finally:
            # Always put the live (non-averaged) weights back, even if evaluation failed.
            if ema is not None:
                EMA.ema_restore_backed_params(backup_params, model)
示例#2
0
    def fit(self, config, device):
        """Build SQuAD fields, iterators, model and optimizer from ``config``, then train and validate.

        Each of ``config["max_iterations"]`` iterations calls ``self.train_epoch`` once, then
        ``self.validate``, and logs the best validation loss / F1 / EM seen so far. When the
        exact-match score exceeds ``config.get("checkpoint_em_threshold", 65)``, the whole
        model is saved with ``torch.save`` under ``saved/``. If ``config["ema"]`` is truthy,
        the EMA-averaged weights are saved. A ``KeyboardInterrupt`` stops training early, and
        the elapsed time is always logged.

        Args:
            config: Dict experiment configuration. Keys read: ``char_embeddings``,
                ``embedding_size``, ``char_maxsize_vocab``, ``batch_size``, ``modelname``,
                ``optimizer``, ``learning_rate``, ``max_iterations``, and optionally ``ema``
                and ``checkpoint_em_threshold``.
            device: Torch device for the iterators and the model.

        Raises:
            NotImplementedError: For an unsupported ``modelname`` or ``optimizer`` setting.
        """
        logging.info(json.dumps(config, indent=4, sort_keys=True))

        if config["char_embeddings"]:
            fields = SquadDataset.prepare_fields_char()
        else:
            fields = SquadDataset.prepare_fields()

        train, val = SquadDataset.splits(fields)
        fields = dict(fields)

        fields["question"].build_vocab(train, val, vectors=GloVe(name='6B', dim=config["embedding_size"]))

        # Exact type test on purpose: torchtext's Field derives from RawField, so isinstance()
        # would also skip vocabulary building for real character fields.
        if type(fields["question_char"]) is not torchtext.data.field.RawField:
            fields["question_char"].build_vocab(train, val, max_size=config["char_maxsize_vocab"])

        def length_sort_key(example):
            # Negated combined question + document length, shared by both iterators.
            return -(len(example.question) + len(example.document))

        train_iter = BucketIterator(train, sort_key=length_sort_key,
                                    shuffle=True, sort=False, sort_within_batch=True,
                                    batch_size=config["batch_size"], train=True,
                                    repeat=False,
                                    device=device)

        val_iter = BucketIterator(val, sort_key=length_sort_key, sort=True,
                                  batch_size=config["batch_size"],
                                  repeat=False,
                                  device=device)

        model_name = config["modelname"]
        if model_name == "baseline":
            model = Baseline(config, fields["question"].vocab).to(device)
        elif model_name == "bidaf_simplified":
            model = BidafSimplified(config, fields["question"].vocab).to(device)
        elif model_name == "bidaf":
            model = BidAF(config, fields['question'].vocab, fields["question_char"].vocab).to(device)
        else:
            # Without this, an unknown name left ``model`` unbound and failed later with NameError.
            raise NotImplementedError(f"Option {model_name} for \"modelname\" setting is undefined.")

        logging.info(f"Models has {count_parameters(model)} parameters")
        param_sizes, param_shapes = report_parameters(model)
        param_sizes = "\n'".join(str(param_sizes).split(", '"))
        param_shapes = "\n'".join(str(param_shapes).split(", '"))
        logging.debug(f"Model structure:\n{param_sizes}\n{param_shapes}\n")

        if config["optimizer"] == "adam":
            optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()),
                             lr=config["learning_rate"])
        else:
            raise NotImplementedError(f"Option {config['optimizer']} for \"optimizer\" setting is undefined.")

        use_ema = bool(config.get("ema"))
        save_threshold_em = config.get("checkpoint_em_threshold", 65)

        start_time = time.time()
        try:
            best_val_loss = math.inf
            best_val_f1 = 0
            best_em = 0
            ema = None
            for it in range(config["max_iterations"]):
                logging.info(f"Iteration {it}")
                if use_ema:
                    # NOTE(review): the EMA state is re-registered at the start of every
                    # iteration, which presumably discards the average accumulated so far —
                    # confirm whether registration should happen once before the loop.
                    ema = EMA.ema_register(config, model)

                self.train_epoch(model, CrossEntropyLoss(), optimizer, train_iter)

                if use_ema:
                    EMA.ema_update(ema, model)

                # ``ema`` is None unless EMA is enabled.
                validation_loss, em, f1 = self.validate(model, CrossEntropyLoss(reduction='none'), val_iter,
                                                        ema=ema)
                best_val_loss = min(best_val_loss, validation_loss)
                best_val_f1 = max(best_val_f1, f1)
                best_em = max(best_em, em)
                logging.info(f"BEST L/F1/EM = {best_val_loss:.2f}/{best_val_f1:.2f}/{best_em:.2f}")

                if em > save_threshold_em:
                    # Both EMA and non-EMA checkpoints use fixed-point (.2f) formatting;
                    # a bare ".2" would produce names like "6.6e+01".
                    checkpoint_path = (f"saved/checkpoint"
                                       f"_{str(self.__class__)}"
                                       f"_EM_{em:.2f}_F1_{f1:.2f}_L_{validation_loss:.2f}_{get_timestamp()}"
                                       f"_{socket.gethostname()}.pt")
                    # Do all this on CPU, this is memory exhaustive!
                    model.to(torch.device("cpu"))
                    if use_ema:
                        # Save the averaged weights; the live weights are restored below.
                        backup_params = EMA.ema_backup_and_loadavg(ema, model)
                    try:
                        torch.save(model, checkpoint_path)
                    finally:
                        if use_ema:
                            EMA.ema_restore_backed_params(backup_params, model)
                        model.to(device)
                logging.info(f"Validation loss: {validation_loss}")

        except KeyboardInterrupt:
            logging.info('-' * 120)
            logging.info('Exit from training early.')
        finally:
            logging.info(f'Finished after {(time.time() - start_time) / 60} minutes.')