示例#1
0
    def get_dev_loss(self, model, max_batches=10):
        """Compute the average per-example loss over (part of) the dev set.

        Args:
            model: model forwarded to ``self.eval_one_batch``.
            max_batches: cap on the number of dev batches evaluated
                (default 10, preserving the original hard-coded limit).
                Pass 0 or a negative value to evaluate the full dev set.

        Returns:
            float: per-example average dev loss over the evaluated batches.
        """
        logging.info("Calculating dev loss...")
        tic = time.time()
        loss_per_batch, batch_lengths = [], []
        for num_batches, batch in enumerate(
                get_batch_generator(self.word2id,
                                    self.dev_context_path,
                                    self.dev_qn_path,
                                    self.dev_ans_path,
                                    config.batch_size,
                                    context_len=config.context_len,
                                    question_len=config.question_len,
                                    discard_long=True),
                start=1):
            loss, _, _ = self.eval_one_batch(batch, model)
            curr_batch_size = batch.batch_size
            # Weight each batch's loss by its size so the final average is a
            # true per-example mean even if the last batch is smaller.
            loss_per_batch.append(loss * curr_batch_size)
            batch_lengths.append(curr_batch_size)
            if max_batches > 0 and num_batches >= max_batches:
                break
        total_num_examples = sum(batch_lengths)
        toc = time.time()
        # BUG FIX: the original line was a Python 2 `print` statement, a
        # syntax error under Python 3. Report via logging, consistent with
        # the sibling methods in this class.
        logging.info("Computed dev loss over %i examples in %.2f seconds" %
                     (total_num_examples, toc - tic))

        dev_loss = sum(loss_per_batch) / float(total_num_examples)

        return dev_loss
示例#2
0
def train(context_path, qn_path, ans_path):
    """Train the network.

    Builds a ``Decoder`` model, restores the latest checkpoint if one
    exists, then runs SGD over batches generated from the given data
    paths, periodically printing metrics and saving checkpoints.

    Args:
        context_path: path to the context (paragraph) data file.
        qn_path: path to the question data file.
        ans_path: path to the answer data file.
    """
    model = Decoder(emb_matrix, 2)
    # Select the parameters which require grad / backpropagation
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.SGD(params,
                          lr=config.learning_rate,
                          weight_decay=config.l2_norm)

    checkpoint_name = "checkpoint-Embed{}-ep{}-iter{}".format(
        config.embedding_dim, 2, 1000)
    checkpoint_name = os.path.join(config.experiments_root_dir,
                                   checkpoint_name)
    # If the network has saved model, restore it
    if os.path.exists(checkpoint_name):
        state = torch.load(checkpoint_name)
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']
        i = state['iter']
        # BUG FIX: checkpoints were saved with the key 'current_loss' but
        # restored via state['loss'], so any restore raised KeyError. The
        # key is now 'loss' on both the save and load sides (matching the
        # checkpoint format used by the other training scripts).
        current_loss = state['loss']
        print("Model restored from ", checkpoint_name)
        print("Epoch : {}\tIter {}\t\tloss : {}".format(
            start_epoch, i, current_loss))
    else:
        print("Training with fresh parameters")

    # For each epoch
    for epoch in range(config.num_epochs):
        # For each batch
        for i, batch in enumerate(
                get_batch_generator(word2index,
                                    context_path,
                                    qn_path,
                                    ans_path,
                                    config.batch_size,
                                    config.context_len,
                                    config.question_len,
                                    discard_long=True)):

            # Take step in training
            loss = step(model, optimizer, batch)

            # Displaying results
            if i % config.print_every == 0:
                f1 = evaluate(model, batch)
                print("Epoch : {}\tIter {}\t\tloss : {}\tf1 : {}".format(
                    epoch, i, "%.2f" % loss, "%.2f" % f1))
                # Maybe you want to do random evaluations as well for sanity check

            # Saving the model
            if i % config.save_every == 0:
                state = {
                    'iter': i,
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': loss,  # key must match the restore code above
                }
                checkpoint_name = "checkpoint-Embed{}-ep{}-iter{}".format(
                    config.embedding_dim, epoch, i)
                checkpoint_name = os.path.join(config.experiments_root_dir,
                                               checkpoint_name)
                torch.save(state, checkpoint_name)
示例#3
0
    def train(self, model_file_path):
        """Run the main training loop.

        Creates a timestamped experiment directory under ``config.log_root``
        (with ``model/`` and ``bestmodel/`` subdirectories), dumps the config
        to JSON, builds the model (restoring from ``model_file_path`` via
        ``self.get_model``), then trains for ``config.num_epochs`` epochs —
        ``num_epochs == 0`` means train indefinitely. Periodically logs the
        smoothed loss, saves checkpoints, and evaluates dev loss plus F1/EM
        on both train and dev splits, writing all metrics to TensorBoard.

        Args:
            model_file_path: checkpoint path forwarded to ``self.get_model``;
                presumably a falsy value means fresh parameters — TODO confirm
                against ``get_model``.
        """
        # Per-run output directories: train_<unix_time>/{model,bestmodel}.
        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)
        model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)
        bestmodel_dir = os.path.join(train_dir, 'bestmodel')
        if not os.path.exists(bestmodel_dir):
            os.makedirs(bestmodel_dir)

        # TensorBoard writer for train/dev scalars.
        summary_writer = tf.summary.FileWriter(train_dir)

        # Snapshot the full config so the run is reproducible.
        with open(os.path.join(train_dir, "flags.json"), 'w') as fout:
            json.dump(vars(config), fout)

        model = self.get_model(model_file_path)
        # Only optimize parameters that require gradients.
        params = list(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer = Adam(params, lr=config.lr, amsgrad=True)

        num_params = sum(p.numel() for p in params)
        logging.info("Number of params: %d" % num_params)

        # exp_loss is an exponentially-smoothed training loss; best_dev_*
        # track the best dev metrics seen so far.
        exp_loss, best_dev_f1, best_dev_em = None, None, None

        epoch = 0
        global_step = 0

        logging.info("Beginning training loop...")
        # num_epochs == 0 acts as "run forever".
        while config.num_epochs == 0 or epoch < config.num_epochs:
            epoch += 1
            epoch_tic = time.time()
            for batch in get_batch_generator(self.word2id,
                                             self.train_context_path,
                                             self.train_qn_path,
                                             self.train_ans_path,
                                             config.batch_size,
                                             context_len=config.context_len,
                                             question_len=config.question_len,
                                             discard_long=True):
                global_step += 1
                iter_tic = time.time()

                loss, param_norm, grad_norm = self.train_one_batch(
                    batch, model, optimizer, params)
                write_summary(loss, "train/loss", summary_writer, global_step)

                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Exponential moving average of the loss (0.99 decay).
                if not exp_loss:
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                if global_step % config.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f'
                        % (epoch, global_step, loss, exp_loss, grad_norm,
                           param_norm, iter_time))

                # Periodic checkpoint of the current model.
                if global_step % config.save_every == 0:
                    logging.info("Saving to %s..." % model_dir)
                    self.save_model(model, optimizer, loss, global_step, epoch,
                                    model_dir)

                # Periodic evaluation: dev loss, then F1/EM on train and dev.
                if global_step % config.eval_every == 0:
                    dev_loss = self.get_dev_loss(model)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" %
                                 (epoch, global_step, dev_loss))
                    write_summary(dev_loss, "dev/loss", summary_writer,
                                  global_step)

                    # Train F1/EM on a 1000-example sample (cheaper than full).
                    train_f1, train_em = self.check_f1_em(model,
                                                          "train",
                                                          num_samples=1000)
                    logging.info(
                        "Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f"
                        % (epoch, global_step, train_f1, train_em))
                    write_summary(train_f1, "train/F1", summary_writer,
                                  global_step)
                    write_summary(train_em, "train/EM", summary_writer,
                                  global_step)

                    # Dev F1/EM over the full dev set (num_samples=0 == all).
                    dev_f1, dev_em = self.check_f1_em(model,
                                                      "dev",
                                                      num_samples=0)
                    logging.info(
                        "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f"
                        % (epoch, global_step, dev_f1, dev_em))
                    write_summary(dev_f1, "dev/F1", summary_writer,
                                  global_step)
                    write_summary(dev_em, "dev/EM", summary_writer,
                                  global_step)

                    if best_dev_f1 is None or dev_f1 > best_dev_f1:
                        best_dev_f1 = dev_f1

                    # NOTE(review): only an EM improvement triggers a save to
                    # bestmodel_dir; an F1 improvement alone only updates the
                    # tracker above. Confirm this asymmetry is intentional.
                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s..." % bestmodel_dir)
                        self.save_model(model, optimizer, loss, global_step,
                                        epoch, bestmodel_dir)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" %
                         (epoch, epoch_toc - epoch_tic))

        sys.stdout.flush()
示例#4
0
    def check_f1_em(self,
                    model,
                    dataset,
                    num_samples=100,
                    print_to_screen=False):
        """Compute average F1 and exact-match scores on a dataset split.

        Args:
            model: model forwarded to ``self.test_one_batch``.
            dataset: which split to evaluate — ``"train"`` or ``"dev"``.
            num_samples: number of examples to evaluate; 0 means all.
            print_to_screen: if True, print each example via
                ``print_example`` for manual inspection.

        Returns:
            tuple: ``(f1_avg, em_avg)`` averaged over the evaluated examples.

        Raises:
            ValueError: if ``dataset`` is neither "train" nor "dev".
        """
        logging.info(
            "Calculating F1/EM for %s examples in %s set..." %
            (str(num_samples) if num_samples != 0 else "all", dataset))

        if dataset == "train":
            context_path, qn_path, ans_path = self.train_context_path, self.train_qn_path, self.train_ans_path
        elif dataset == "dev":
            context_path, qn_path, ans_path = self.dev_context_path, self.dev_qn_path, self.dev_ans_path
        else:
            # BUG FIX: the original `raise ('dataset is not defined')` raised
            # a bare string, which is a TypeError ("exceptions must derive
            # from BaseException") rather than a meaningful error.
            raise ValueError('dataset is not defined')

        f1_total = 0.
        em_total = 0.
        example_num = 0

        tic = time.time()

        # discard_long=False: evaluation keeps over-length examples
        # (unlike training, which discards them).
        for batch in get_batch_generator(self.word2id,
                                         context_path,
                                         qn_path,
                                         ans_path,
                                         config.batch_size,
                                         context_len=config.context_len,
                                         question_len=config.question_len,
                                         discard_long=False):

            pred_start_pos, pred_end_pos = self.test_one_batch(batch, model)

            pred_start_pos = pred_start_pos.tolist()
            pred_end_pos = pred_end_pos.tolist()

            for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) \
                    in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
                example_num += 1
                # Predicted answer is the inclusive token span
                # [pred_ans_start, pred_ans_end] of the context.
                pred_ans_tokens = batch.context_tokens[ex_idx][
                    pred_ans_start:pred_ans_end + 1]
                pred_answer = " ".join(pred_ans_tokens)

                true_answer = " ".join(true_ans_tokens)

                f1 = f1_score(pred_answer, true_answer)
                em = exact_match_score(pred_answer, true_answer)
                f1_total += f1
                em_total += em

                if print_to_screen:
                    print_example(self.word2id, batch.context_tokens[ex_idx],
                                  batch.qn_tokens[ex_idx],
                                  batch.ans_span[ex_idx,
                                                 0], batch.ans_span[ex_idx, 1],
                                  pred_ans_start, pred_ans_end, true_answer,
                                  pred_answer, f1, em)

                # Stop once the requested sample count is reached
                # (num_samples == 0 means evaluate everything).
                if num_samples != 0 and example_num >= num_samples:
                    break

            if num_samples != 0 and example_num >= num_samples:
                break

        f1_total /= example_num
        em_total /= example_num

        toc = time.time()
        logging.info(
            "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
            (example_num, dataset, toc - tic))

        return f1_total, em_total
示例#5
0
def train(context_path, qn_path, ans_path):
    """ Train the network

    Builds a CoattentionNetwork from ``config``, sets up a serial-numbered
    experiment directory, optionally restores the most recent checkpoint
    (when ``config.restore`` is set), then runs SGD over batches from the
    given paths, periodically printing, evaluating and checkpointing.

    Args:
        context_path: path to the context (paragraph) data file.
        qn_path: path to the question data file.
        ans_path: path to the answer data file.
    """

    model = N.CoattentionNetwork(
        device=config.device,
        hidden_size=config.hidden_size,
        emb_matrix=emb_matrix,
        num_encoder_layers=config.num_encoder_layers,
        num_fusion_bilstm_layers=config.num_fusion_bilstm_layers,
        num_decoder_layers=config.num_decoder_layers,
        batch_size=config.batch_size,
        max_dec_steps=config.max_dec_steps,
        fusion_dropout_rate=config.fusion_dropout_rate,
        encoder_bidirectional=config.encoder_bidirectional,
        decoder_bidirectional=config.decoder_bidirectional)

    # Select the parameters which require grad / backpropagation
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.SGD(params,
                          lr=config.learning_rate,
                          weight_decay=config.l2_norm)

    # Set up directories for this experiment
    if not os.path.exists(config.experiments_root_dir):
        os.makedirs(config.experiments_root_dir)

    # Experiments are numbered by how many directories already exist;
    # NOTE(review): this assumes experiments_root_dir contains only
    # experiment_* directories — confirm nothing else is written there.
    serial_number = len(os.listdir(config.experiments_root_dir))
    if config.restore:
        serial_number -= 1  # Check into the latest model
    experiment_dir = os.path.join(config.experiments_root_dir,
                                  'experiment_{}'.format(serial_number))

    if not os.path.exists(experiment_dir):
        os.makedirs(experiment_dir)
    model_dir = os.path.join(experiment_dir, 'model')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Save config as config.json
    with open(os.path.join(experiment_dir, "config.json"), 'w') as fout:
        json.dump(vars(config), fout)

    iteration = 0
    if config.restore:
        saved_models = os.listdir(model_dir)
        if len(saved_models):
            print(saved_models)
            # Checkpoint names end in "-<iteration>"; restore the newest one.
            saved_models = [int(name.split('-')[-1]) for name in saved_models]
            latest_iter = max(saved_models)
            checkpoint_name = "checkpoint-embed{}-iter-{}".format(
                config.embedding_dim, latest_iter)
            checkpoint_name = os.path.join(model_dir, checkpoint_name)

            state = torch.load(checkpoint_name)
            model.load_state_dict(state['model'])
            optimizer.load_state_dict(state['optimizer'])
            iteration = state['iter']
            print("Model restored from ", checkpoint_name)
        else:
            print("Training with fresh parameters")

    for epoch in range(config.num_epochs):
        for batch in get_batch_generator(word2index,
                                         context_path,
                                         qn_path,
                                         ans_path,
                                         config.batch_size,
                                         config.context_len,
                                         config.question_len,
                                         discard_long=True):

            # When the batch is partially filled, ignore it.
            if batch.batch_size < config.batch_size:
                del batch
                continue

            # Take step in training
            loss = step(model, optimizer, batch, params)

            # Displaying results (f1 is a -1 placeholder here; the real F1
            # is only computed on the evaluate_every schedule below).
            if iteration % config.print_every == 0:
                print("Iter {}\t\tloss : {}\tf1 : {}".format(
                    iteration, "%.5f" % loss, "%.4f" % -1))

            if iteration % config.evaluate_every == 0:
                f1 = evaluate(model, batch)
                print("Iter {}\t\tloss : {}\tf1 : {}".format(
                    iteration, "%.5f" % loss, "%.4f" % f1))
                # Maybe you want to do random evaluations as well for sanity check

            # Saving the model
            if iteration % config.save_every == 0:
                state = {
                    'iter': iteration,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': loss
                }
                checkpoint_name = "checkpoint-embed{}-iter-{}".format(
                    config.embedding_dim, iteration)

                fname = os.path.join(model_dir, checkpoint_name)
                torch.save(state, fname)

            # Drop the loss tensor promptly to free graph memory each step.
            del loss
            iteration += 1