    def _validate(self, predictor, dataset):
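        """
        Run a trained CRAFT model over a validation set in a batched fashion and compute accuracy.

        :param predictor: the trained CRAFT model to use, provided as a PyTorch Model instance.
        :param dataset: the labeled validation data, in the same format expected by batchIterator.
        :return: the fraction of predictions that match the ground-truth labels, as a float.
        """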
        # create a batch iterator for the given data
        batch_iterator = batchIterator(self.voc,
                                       dataset,
                                       self.options['batch_size'],
                                       shuffle=False)
        # find out how many iterations we will need to cover the whole dataset
        n_iters = len(dataset) // self.options['batch_size'] + int(
            len(dataset) % self.options['batch_size'] > 0)
        # containers for full prediction results so we can compute accuracy at the end
        all_preds = []
        all_labels = []
        for iteration in range(1, n_iters + 1):
            batch, batch_dialogs, true_batch_size = next(batch_iterator)
            # Extract fields from batch
            input_variable, dialog_lengths, utt_lengths, batch_indices, dialog_indices, \
                batch_labels, batch_ids, target_variable, mask, max_target_len = batch
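            # length (number of utterances) of each dialog in the batch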
            dialog_lengths_list = [len(x) for x in batch_dialogs]
            # run the model
            predictions, scores = self._evaluate_batch(
                predictor, input_variable, dialog_lengths, dialog_lengths_list,
                utt_lengths, batch_indices, dialog_indices, true_batch_size)
            # aggregate results for computing accuracy at the end
            all_preds += [p.item() for p in predictions]
            all_labels += [label.item() for label in batch_labels]
            print("Iteration: {}; Percent complete: {:.1f}%".format(
                iteration, iteration / n_iters * 100))

        # compute and return the accuracy
        return (np.asarray(all_preds) == np.asarray(all_labels)).mean()

    def _evaluate_dataset(self, predictor, dataset):
        """
        Run a trained CRAFT model over an entire dataset in a batched fashion.

        :param predictor: the trained CRAFT model to use, provided as a PyTorch Model instance.
        :param dataset: the dataset to evaluate on, formatted as a list of (context, reply, id_of_reply) tuples.
        :return: a DataFrame, indexed by utterance ID, of CRAFT scores for each utterance, and the corresponding binary prediction.
        """
        # create a batch iterator for the given data
        batch_iterator = batchIterator(self.voc,
                                       dataset,
                                       self.options['batch_size'],
                                       shuffle=False)
        # find out how many iterations we will need to cover the whole dataset
        n_iters = len(dataset) // self.options['batch_size'] + int(
            len(dataset) % self.options['batch_size'] > 0)
        output_df = {
            "id": [],
            self.forecast_feat_name: [],
            self.forecast_prob_feat_name: []
        }
        for iteration in range(1, n_iters + 1):
            batch, batch_dialogs, true_batch_size = next(batch_iterator)
            # Extract fields from batch
            input_variable, dialog_lengths, utt_lengths, batch_indices, dialog_indices, \
                labels, batch_ids, target_variable, mask, max_target_len = batch
            dialog_lengths_list = [len(x) for x in batch_dialogs]
            # run the model
            predictions, scores = self._evaluate_batch(
                predictor, input_variable, dialog_lengths, dialog_lengths_list,
                utt_lengths, batch_indices, dialog_indices, true_batch_size)

            # format the output as a dataframe (which we can later re-join with the corpus)
            for i in range(true_batch_size):
                utt_id = batch_ids[i]
                pred = predictions[i].item()
                score = scores[i].item()
                output_df["id"].append(utt_id)
                output_df[self.forecast_feat_name].append(pred)
                output_df[self.forecast_prob_feat_name].append(score)

            print("Iteration: {}; Percent complete: {:.1f}%".format(
                iteration, iteration / n_iters * 100))

        return pd.DataFrame(output_df).set_index("id")

    def _train_iters(self, train_pairs, val_pairs, encoder, context_encoder,
                     attack_clf, encoder_optimizer, context_encoder_optimizer,
                     attack_clf_optimizer, embedding, n_iteration,
                     validate_every):
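        """
        Run the main training loop: draw batches from the training data, periodically report
        average loss, evaluate on the validation set, and checkpoint the model whenever
        validation accuracy improves on the best seen so far.

        :param train_pairs: labeled training data, in the format expected by batchIterator.
        :param val_pairs: labeled validation data, used for the periodic accuracy evaluation.
        :param encoder: the utterance encoder component of the CRAFT model.
        :param context_encoder: the context encoder component of the CRAFT model.
        :param attack_clf: the attack classifier (prediction head) of the CRAFT model.
        :param encoder_optimizer: optimizer for the utterance encoder.
        :param context_encoder_optimizer: optimizer for the context encoder.
        :param attack_clf_optimizer: optimizer for the attack classifier.
        :param embedding: the word embedding layer, saved as part of each checkpoint.
        :param n_iteration: total number of training iterations to run.
        :param validate_every: number of iterations between validation passes.
        """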

        # create a batch iterator for training data
        batch_iterator = batchIterator(self.voc, train_pairs,
                                       self.options['batch_size'])

        # Initializations
        print('Initializing ...')
        start_iteration = 1
        print_loss = 0

        # Training loop
        print("Training...")
        # keep track of best validation accuracy - only save when we have a model that beats the current best
        best_acc = 0
        for iteration in range(start_iteration, n_iteration + 1):
            training_batch, training_dialogs, true_batch_size = next(
                batch_iterator)
            # Extract fields from batch
            input_variable, dialog_lengths, utt_lengths, batch_indices, dialog_indices, \
                labels, batch_ids, target_variable, mask, max_target_len = training_batch
            dialog_lengths_list = [len(x) for x in training_dialogs]

            # Run a training iteration with batch
            loss = self._train_NN(
                input_variable,
                dialog_lengths,
                dialog_lengths_list,
                utt_lengths,
                batch_indices,
                dialog_indices,
                labels,  # input/output arguments
                encoder,
                context_encoder,
                attack_clf,  # network arguments
                encoder_optimizer,
                context_encoder_optimizer,
                attack_clf_optimizer,  # optimization arguments
                true_batch_size,
                self.options['clip'])  # misc arguments
            print_loss += loss

            # Print progress
            if iteration % self.options['print_every'] == 0:
                print_loss_avg = print_loss / self.options['print_every']
                print(
                    "Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}"
                    .format(iteration, iteration / n_iteration * 100,
                            print_loss_avg))
                print_loss = 0

            # Evaluate on validation set
            if iteration % validate_every == 0:
                print("Validating!")
                # put the network components into evaluation mode
                encoder.eval()
                context_encoder.eval()
                attack_clf.eval()

                predictor = Predictor(encoder, context_encoder, attack_clf)
                accuracy = self._validate(predictor, val_pairs)
                print("Validation set accuracy: {:.2f}%".format(accuracy *
                                                                100))

                # keep track of our best model so far
                if accuracy > best_acc:
                    print(
                        "Validation accuracy better than current best; saving model..."
                    )
                    best_acc = accuracy
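                    # checkpoint everything needed to resume training or run inference:
                    # component weights, optimizer states, vocabulary, and embeddings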
                    torch.save(
                        {
                            'iteration': iteration,
                            'en': encoder.state_dict(),
                            'ctx': context_encoder.state_dict(),
                            'atk_clf': attack_clf.state_dict(),
                            'en_opt': encoder_optimizer.state_dict(),
                            'ctx_opt': context_encoder_optimizer.state_dict(),
                            'atk_clf_opt': attack_clf_optimizer.state_dict(),
                            'loss': loss,
                            'voc_dict': self.voc.__dict__,
                            'embedding': embedding.state_dict()
                        }, self.options['trained_model_output_filepath'])

                # put the network components back into training mode
                encoder.train()
                context_encoder.train()
                attack_clf.train()