Пример #1
0
def sample_from_class(
    dataset: torch.utils.data.Dataset,
    n_samples: int,
) -> torch.utils.data.Dataset:
    """
    Stratified sampling: keep at most ``n_samples`` examples of every class.

    Samples are visited in their stored order, so for each class the first
    ``n_samples`` occurrences are retained and later ones are dropped. The
    ``data`` and ``targets`` attributes of ``dataset`` are overwritten in
    place and the very same object is returned.

    :param dataset: Original dataset exposing ``data`` and ``targets``.
    :param n_samples: Number of samples to use from each class.
    :return: ``dataset`` itself, now holding only the retained samples.
    """
    seen_per_class = {}
    kept_data, kept_labels = [], []

    for position, sample in enumerate(dataset.data):
        target = dataset.targets[position]
        key = target.item()
        seen = seen_per_class.get(key, 0) + 1
        seen_per_class[key] = seen
        if seen > n_samples:
            # The quota of this class is already filled.
            continue
        # Add a leading axis so the kept items can be concatenated below.
        kept_data.append(torch.unsqueeze(sample, 0))
        kept_labels.append(torch.unsqueeze(target, 0))

    dataset.data = torch.cat(kept_data)
    dataset.targets = torch.cat(kept_labels)

    return dataset
Пример #2
0
def _train_val_split(
        rnd: np.random.RandomState,
        train_dataset: torch.utils.data.Dataset,
        validation_ratio=0.05
) -> tuple:
    """
    Carve a stratified validation set out of a PyTorch training dataset.

    The ``data`` / ``targets`` attributes of ``train_dataset`` are split with
    ``sk_train_val_split`` (stratified on the targets). The validation set is
    a deep copy of ``train_dataset`` whose ``data`` / ``targets`` are replaced
    by the validation part; ``train_dataset`` itself is modified in place to
    hold only the training part.

    :param rnd: `np.random.RandomState` instance passed as ``random_state``.
    :param train_dataset: Training set. This is an instance of PyTorch's dataset.
    :param validation_ratio: The ratio of validation data.

    :return: Tuple of training set and validation set.
    """
    features = train_dataset.data
    labels = train_dataset.targets
    parts = sk_train_val_split(
        features,
        labels,
        test_size=validation_ratio,
        random_state=rnd,
        stratify=labels,
    )
    x_train, x_val, y_train, y_val = parts

    # Copy before the original is overwritten so that every other attribute
    # of the dataset is carried over to the validation set.
    val_dataset = copy.deepcopy(train_dataset)
    val_dataset.data, val_dataset.targets = x_val, y_val

    train_dataset.data, train_dataset.targets = x_train, y_train

    return train_dataset, val_dataset
Пример #3
0
def replace_indexes(dataset: torch.utils.data.Dataset,
                    indexes: Union[List[int], np.ndarray],
                    seed=0,
                    only_mark: bool = False):
    """
    Overwrite or mark the samples of ``dataset`` located at ``indexes``.

    With ``only_mark`` the data is left alone and every selected target ``t``
    is rewritten as ``-t - 1``. Otherwise each selected position receives the
    data and target of a position drawn (by a ``RandomState`` seeded with
    ``seed``) from the positions that are *not* listed in ``indexes``.

    The dataset is modified in place; nothing is returned.
    """
    if only_mark:
        # Notice the -1 to make class 0 work
        dataset.targets[indexes] = -dataset.targets[indexes] - 1
        return

    # Positions that may serve as replacement donors.
    donors = list(set(range(len(dataset))) - set(indexes))
    generator = np.random.RandomState(seed)
    new_indexes = generator.choice(donors, size=len(indexes))

    dataset.data[indexes] = dataset.data[new_indexes]
    dataset.targets[indexes] = dataset.targets[new_indexes]
Пример #4
0
def set_up_model(model: torch.nn.Module, ds: torch.utils.data.Dataset,
                 kappa: float, g: float, batch_size: int):
    """
    Run one forward pass and one ``relprop`` pass through ``model``.

    A batch for the parameters ``(kappa, g)`` is fetched from ``ds`` and fed
    to the model; afterwards the binned representation of ``(kappa, g)`` is
    repeated ``batch_size`` times, flattened per sample and handed to
    ``model.relprop``. Both results are discarded - per the original note the
    purpose is that layers and nodes get a relevance assigned internally.
    """
    # forward pass (output is not needed)
    inputs = torch.tensor(
        ds.get_parameter_batch(kappa, g, batch_size=batch_size))
    np.squeeze(model(inputs)).detach().numpy()

    # backward pass: one copy of the target bins per sample of the batch
    bins = ds.parameter_to_bins(kappa, g)
    repeated = np.tile(bins[np.newaxis], [batch_size, 1, 1])
    model.relprop(repeated.reshape(batch_size, -1))
def subsample_dataset(dataset: torch.utils.data.Dataset,
                      num_samples,
                      class_weights: dict,
                      copy_dataset=True):
    """
    Restrict ``dataset.file_paths`` to a weighted subset of its files.

    The share of ``num_samples`` given to each class is proportional to its
    entry in ``class_weights`` (rounded up): weight ``1`` governs how many
    leading entries of ``dataset.local`` are taken, weight ``0`` how many
    leading entries of ``dataset.noise``.

    When ``copy_dataset`` is true a deep copy is modified and returned,
    otherwise ``dataset`` itself is modified and returned.
    """
    target = copy.deepcopy(dataset) if copy_dataset else dataset
    total_weight = sum(class_weights.values())

    def quota(label):
        # Number of files granted to the class with the given label.
        return math.ceil(class_weights[label] / total_weight * num_samples)

    n_local = quota(1)
    n_noise = quota(0)
    target.file_paths = target.local[:n_local] + target.noise[:n_noise]
    return target
Пример #6
0
def calculate_observables(
        ds: torch.utils.data.Dataset) -> Tuple[np.ndarray, np.ndarray]:
    """
    Evaluate two observables on the ``(g, kappa)`` parameter grid of ``ds``.

    For every combination of ``ds.g`` and ``ds.kappa`` a batch of 5 samples
    is requested from ``ds``. The mean of the batch is stored as the
    magnetization and the mean of ``get_staggered_magnetization(batch)`` as
    the staggered magnetization.

    Returns two arrays of shape ``(len(ds.g), len(ds.kappa))``:
    ``(magnetization, staggered_magnetization)``.
    """
    grid_shape = (len(ds.g), len(ds.kappa))
    magnetization = np.zeros(grid_shape)
    staggered_magnetization = np.zeros(grid_shape)

    for row, g_value in enumerate(ds.g):
        for col, kappa_value in enumerate(ds.kappa):
            samples = ds.get_parameter_batch(kappa_value, g_value,
                                             batch_size=5)
            magnetization[row, col] = np.mean(samples)
            staggered = get_staggered_magnetization(samples)
            staggered_magnetization[row, col] = np.mean(staggered)

    return magnetization, staggered_magnetization
Пример #7
0
def subsample_dataset(dataset: torch.utils.data.Dataset,
                      num_samples,
                      class_weights: dict,
                      random_shuffle=False,
                      copy_dataset=True):
    """
    Restrict ``dataset.file_paths`` to a weighted subset and shuffle it.

    The share of ``num_samples`` given to each class is proportional to its
    entry in ``class_weights`` (rounded up): weight ``1`` applies to
    ``dataset.local`` and weight ``0`` to ``dataset.noise``. The files are
    either the leading entries of each list or, with ``random_shuffle``,
    drawn with ``random.sample``. Finally ``dataset.shuffle()`` is called.

    When ``copy_dataset`` is true a deep copy is modified and returned,
    otherwise ``dataset`` itself is modified and returned.
    """
    if copy_dataset:
        dataset = copy.deepcopy(dataset)

    total_weight = sum(class_weights.values())
    n_local = math.ceil(class_weights[1] / total_weight * num_samples)
    n_noise = math.ceil(class_weights[0] / total_weight * num_samples)

    if not random_shuffle:
        local_part = dataset.local[:n_local]
        noise_part = dataset.noise[:n_noise]
    else:
        # Draw the local files first, then the noise files.
        local_part = random.sample(dataset.local, n_local)
        noise_part = random.sample(dataset.noise, n_noise)
    dataset.file_paths = local_part + noise_part

    dataset.shuffle()

    return dataset
Пример #8
0
def measure_performance(
        model: nn.Module,
        ds: torch.utils.data.Dataset) -> Tuple[np.ndarray, np.ndarray]:
    """
    Measure the prediction error of ``model`` on the parameter grid of ``ds``.

    For every ``(kappa, g)`` pair in ``ds.labels`` a batch of 5 samples is
    fetched and passed through the model. If the last output dimension is
    larger than 2 the output is treated as a binned regression and decoded
    with ``ds.bins_to_parameter``; otherwise columns 0 and 1 are taken as the
    kappa and g predictions. Predictions are averaged over the batch.

    The averaged predictions are reshaped to ``(len(ds.g), len(ds.kappa))``
    (assumes ``ds.labels`` enumerates the grid in that order - TODO confirm)
    and the deviation from the true value is expressed in units of the
    spacing ``ds.kappa[0] - ds.kappa[1]`` resp. ``ds.g[0] - ds.g[1]``,
    rounded to 8 decimals.

    :param model: Callable network mapping a batch tensor to predictions.
    :param ds: Dataset exposing ``labels``, ``kappa``, ``g``,
        ``get_parameter_batch`` and (for binned output) ``bins_to_parameter``.
    :return: ``(error_kappa, error_g)``, both of shape
        ``(len(ds.g), len(ds.kappa))``.
    """
    k_mean, g_mean = [], []
    # Pure inference: no autograd graph is needed for the predictions.
    with torch.no_grad():
        for kappa, g in ds.labels:
            batch = ds.get_parameter_batch(kappa, g, batch_size=5)
            output = model(torch.tensor(batch)).detach().numpy()
            if output.shape[-1] > 2:        # binned regression
                k_prediction, g_prediction = ds.bins_to_parameter(output)
            else:
                k_prediction, g_prediction = output[:, 0], output[:, 1]
            k_mean.append(k_prediction.mean())
            g_mean.append(g_prediction.mean())

    grid_shape = (len(ds.g), len(ds.kappa))

    k_mean = np.array(k_mean).reshape(grid_shape)
    delta_kappa = np.round(ds.kappa[0] - ds.kappa[1], 8)
    error_kappa = np.round((k_mean - ds.kappa) / delta_kappa, 8)

    g_mean = np.array(g_mean).reshape(grid_shape)
    delta_g = np.round(ds.g[0] - ds.g[1], 8)
    # Transpose so that ds.g broadcasts along the last axis, then undo it.
    error_g = np.round((g_mean.T - ds.g) / delta_g, 8)
    return error_kappa, error_g.T
Пример #9
0
def run_nn_model(model_path, test_dataset: torch.utils.data.Dataset,
                 experiment_name: str) -> dict:
    """
    Predict a class for every item of a test dataset with a stored network
    ------
    @param model_path: The path where a trained model can be found
    @param test_dataset: The test data
    @param experiment_name: The folder name of the current experiment
        (not used inside this function)
    ------
    @return dictionary mapping each test file id to its predicted class,
        where classes are shifted by +1 relative to the network output index
    """
    # Test dataset
    if type(test_dataset) == MusicDataset:
        loaded = test_dataset.get_whole_dataset_labels_zero_based()
    else:
        loaded = test_dataset.get_whole_dataset()
    test_files, X_test, _ = loaded

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    # Load the stored model onto the selected device
    model = torch.load(model_path, map_location=device)

    # Predict one sample at a time
    print('Predicting classes...')
    predicted = []
    for sample in tqdm(X_test, position=0, leave=True):
        inputs = torch.tensor(sample).unsqueeze(0).float().to(device)
        scores = model(inputs).detach().cpu().numpy()
        predicted.append(np.argmax(scores, axis=1).item())

    # Labels not zero_based in the nn model
    return {
        file_id: predicted[position] + 1
        for position, file_id in enumerate(test_files)
    }
Пример #10
0
def random_dataset_split(dataset: torch.utils.data.Dataset,
                         split_sizes: tuple = (3 / 5., 1 / 5., 1 / 5.),
                         rnd_seed: int = 123,
                         verbose: bool = True):
    """ Partition a torch.utils.data.Dataset into several random, disjoint splits.

     The sample indices of the dataset are shuffled with a seeded random generator and then cut into consecutive
     chunks whose lengths follow split_sizes. Every chunk is wrapped in its own DatasetSubset instance.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset to split
    split_sizes : tuple
        Split sizes as fractions of the original dataset size (the fractions are expected to sum up to 1).
    rnd_seed : int
        Seed for random generator
    verbose : bool
        Print the requested fractions and the resulting number of samples per split

    Returns
    ----------
    dataset_splits: list of DatasetSubset
        list of dataset splits, one per entry of split_sizes
    """
    n_total = len(dataset)
    shuffled = np.arange(n_total)
    np.random.RandomState(rnd_seed).shuffle(shuffled)

    def boundary(n_splits):
        # Position in the shuffled indices where the first n_splits splits end.
        return int(n_total * sum(split_sizes[:n_splits]))

    split_inds_arrays = []
    for split_i in range(len(split_sizes)):
        begin, end = boundary(split_i), boundary(split_i + 1)
        split_inds_arrays.append(shuffled[begin:end])

    if verbose:
        print(f"Split dataset into splits {split_sizes} "
              f"with {[len(s) for s in split_inds_arrays]} samples per split")

    return [DatasetSubset(dataset, inds) for inds in split_inds_arrays]
Пример #11
0
    def validate_on(self, set: torch.utils.data.Dataset,
                    loader: torch.utils.data.DataLoader) -> Tuple[Any, float]:
        """
        Evaluate the model on one dataset and report its test state and loss.

        The model is switched to eval mode and, without gradient tracking,
        every batch of ``loader`` is moved to the device, run through the
        model interface and decoded. The decoded outputs are fed to the test
        object obtained from ``set.start_test()`` and the batch loss is
        accumulated weighted by the batch size. The model is put back into
        training mode before returning.

        Returns the test object and the accumulated loss divided by
        ``len(set)``.
        """
        self.model.eval()

        total_loss = 0
        with torch.no_grad():
            test = set.start_test()
            for batch in tqdm(loader):
                batch = self.helper.to_device(batch)
                result = self.model_interface(batch)
                decoded = self.model_interface.decode_outputs(result)
                # Weight by batch size so the final value is a per-sample mean.
                total_loss += result.loss.item() * result.batch_size

                test.step(decoded, batch)

        self.model.train()
        return test, total_loss / len(set)
Пример #12
0
    def run_active_weasul(self,
                          label_matrix: np.ndarray,
                          y_train: np.ndarray,
                          cliques: list,
                          class_balance: np.ndarray,
                          label_matrix_test: np.ndarray = None,
                          y_test: np.ndarray = None,
                          train_dataset: torch.utils.data.Dataset = None,
                          test_dataset: torch.utils.data.Dataset = None):
        """Iteratively label points, refit label model and return adjusted probabilistic labels.

        The label model is fitted ``self.it + 1`` times: once before any point is
        queried and once after each of the ``self.it`` queries. Every iteration is
        reported through ``self.log``.

        Args:
            label_matrix (numpy.array): Array with labeling function outputs on train set
            y_train (numpy.array): Ground truth labels of training dataset
            cliques (list): List of lists of maximal cliques (column indices of label matrix)
            class_balance (numpy.array): Array with true class distribution
            label_matrix_test (numpy.array, optional): Array with labeling function outputs on test set
            y_test (numpy.array, optional): Ground truth labels of test set
            train_dataset (torch.utils.data.Dataset, optional): Train dataset if training
                discriminative model on image data. Should be
                custom dataset with attribute Y containing target labels.
            test_dataset (torch.utils.data.Dataset, optional): Test dataset if training
                discriminative model on image data

        Returns:
            torch.Tensor: Tensor with probabilistic labels for training dataset
        """
        # If any of the three test inputs is missing, all of them are replaced
        # by the corresponding training inputs.
        if any(v is None for v in (label_matrix_test, y_test, test_dataset)):
            label_matrix_test = label_matrix.copy()
            y_test = y_train.copy()
            test_dataset = train_dataset

        # Work on copies: the label matrix may be overwritten further below.
        self.label_matrix = label_matrix.copy()
        self.label_matrix_test = label_matrix_test.copy()
        self.y_train = y_train.copy()
        self.y_test = y_test.copy()

        # Starts as all -1; entries are overwritten with the true label once a
        # point has been queried (see the end of the loop).
        self.ground_truth_labels = np.full_like(y_train, -1)

        # NOTE(review): test_dataset is still None here when train_dataset was
        # not given either - confirm DataLoader is only iterated when a
        # discriminative model is used.
        dl_test = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=self.batch_size)

        if self.discriminative_model is not None and self.discriminative_model.early_stopping:
            # Split into train and validation sets for early stopping
            indices_shuffle = np.random.permutation(len(self.label_matrix))
            split_nr = int(np.ceil(0.9 * len(self.label_matrix)))
            self.train_idx, val_idx = indices_shuffle[:
                                                      split_nr], indices_shuffle[
                                                          split_nr:]
        else:
            # val_idx stays undefined in this branch; it is only read below
            # when early stopping is enabled.
            self.train_idx = range(len(self.y_train))

        # Identify buckets
        self.unique_combs, self.unique_idx, self.unique_inverse = np.unique(
            label_matrix, return_index=True, return_inverse=True, axis=0)
        # Maps the bucket number to a "-"-joined string of the labeling
        # function outputs of that bucket (the key expression equals i).
        self.bucket_conf_dict = {
            range(len(self.unique_idx))[i]: "-".join([str(e) for e in row])
            for i, row in enumerate(self.label_matrix[self.unique_idx, :])
        }

        for i in range(self.it + 1):

            # Fit label model and predict to obtain probabilistic labels
            prob_labels_train = self.label_model.fit(
                label_matrix=self.label_matrix,
                cliques=cliques,
                class_balance=class_balance,
                ground_truth_labels=self.ground_truth_labels).predict()
            prob_labels_test = self.label_model.predict(
                self.label_matrix_test, self.label_model.mu,
                self.label_model.E_S)

            # Optionally, train discriminative model on probabilistic labels
            if self.discriminative_model is not None and i % self.discr_model_frequency == 0:
                discriminative_model_probs_train = prob_labels_train.clone(
                ).detach()
                # Replace probabilistic labels with ground truth for labelled points
                discriminative_model_probs_train[self.ground_truth_labels ==
                                                 1, :] = (torch.DoubleTensor(
                                                     [0, 1]))
                discriminative_model_probs_train[self.ground_truth_labels ==
                                                 0, :] = (torch.DoubleTensor(
                                                     [1, 0]))
                # NOTE(review): requires train_dataset to be provided even
                # though its default is None.
                train_dataset.Y = discriminative_model_probs_train

                if i > 0:
                    # Reset discriminative model parameters to train with updated labels
                    self.discriminative_model.reset()
                dl_train = DataLoader(
                    CustomTensorDataset(*train_dataset[self.train_idx]),
                    shuffle=True,
                    batch_size=self.batch_size)
                if self.discriminative_model.early_stopping:
                    dl_val = DataLoader(
                        CustomTensorDataset(*train_dataset[val_idx]),
                        shuffle=True,
                        batch_size=self.batch_size)
                else:
                    dl_val = None
                preds_train = self.discriminative_model.fit(dl_train,
                                                            dl_val).predict()
                preds_test = self.discriminative_model.predict(dl_test)
            else:
                preds_train = None
                preds_test = None

            if i == 0:
                # No point has been queried before the first fit.
                sel_idx = None
                # Different seed for rest of the pipeline after first label model fit
                set_seed(self.seed)

                # Switch to active learning mode
                self.label_model.active_learning = True
                self.label_model.penalty_strength = self.penalty_strength

            # sel_idx is the point queried at the end of the previous
            # iteration (None in the first one).
            self.log(count=i,
                     lm_train=prob_labels_train,
                     lm_test=prob_labels_test,
                     fm_train=preds_train,
                     fm_test=preds_test,
                     selected_point=sel_idx)

            # No query after the final fit.
            if i < self.it:
                # Query point and add to ground truth labels
                sel_idx = self.sample(prob_labels_train)
                self.ground_truth_labels[sel_idx] = self.y_train[sel_idx]

                if self.query_strategy == "nashaat":
                    self.label_model.active_learning = False

                    # Nashaat et al. replace labeling function outputs by ground truth
                    self.label_matrix[sel_idx, :] = self.y_train[sel_idx]

        return prob_labels_train
Пример #13
0
def train_CAV(concepts: List[Concept],
              activations: Union[Tuple, np.ndarray],
              dataset: torch.utils.data.Dataset,
              train: List[List[int]],
              val: List[List[int]],
              gpu: bool,
              epochs: int,
              batch_size: int,
              nw: int,
              verbose: bool = False,
              queue: Queue = None) -> List[Dict]:
    """
    Train and evaluate one CAV per concept.

    Parameters
    ----------
    concepts: List[Concept]
        Concepts to train a CAV for
    activations: Union[Tuple, np.ndarray]
        Either the activations themselves or a ``(filename, shape)``
        tuple from which a read-only float memmap is reopened
    dataset: torch.utils.data.Dataset
        Dataset the train/val subsets are taken from. Its
        ``target_concept``, ``return_index`` and ``skip_image``
        attributes are modified
    train: List[List[int]]
        Training indices, one list per concept
    val: List[List[int]]
        Validation indices, one list per concept
    gpu: bool
        Flag forwarded to ``CAV`` and ``eval_CAV``
    epochs: int
        Number of training epochs
    batch_size: int
        Batch size for training and evaluation
    nw: int
        Currently unused: ``CAV`` and ``eval_CAV`` are always called
        with ``nw=0``
    verbose: bool, optional
        Forwarded to ``CAV`` / ``eval_CAV``; also prints the statistics
        of every concept
    queue: Queue, optional
        If given, ``1`` is put on it after each concept

    Returns
    -------
    result: List[Dict]
        One dictionary per concept with the keys ``train_set``,
        ``val_set``, ``train_stat``, ``val_stat`` and ``cav``
    """

    # Eventually reload the memmap
    if isinstance(activations, tuple):
        activations = np.memmap(activations[0],
                                dtype=float,
                                mode='r',
                                shape=activations[1])

    result = []
    for concept, train_idx, val_idx in zip(concepts, train, val):
        # Split dataset
        dataset.target_concept = concept
        dataset.return_index = True
        dataset.skip_image = True

        train_set = torch.utils.data.Subset(dataset, train_idx)
        val_set = torch.utils.data.Subset(dataset, val_idx)

        # Train CAV
        # NOTE(review): ``nw`` is deliberately left at 0 as in the original
        # code; confirm whether the ``nw`` argument should be forwarded.
        cav = CAV(concept,
                  train_set,
                  activations,
                  batch_size=batch_size,
                  epochs=epochs,
                  verbose=verbose,
                  gpu=gpu,
                  nw=0,
                  criterion=FocalLoss)

        # Evaluate CAV (Train)
        train_stat = eval_CAV(cav,
                              concept,
                              train_set,
                              activations,
                              batch_size=batch_size,
                              verbose=verbose,
                              gpu=gpu,
                              nw=0)

        # Evaluate CAV (Val)
        val_stat = eval_CAV(cav,
                            concept,
                            val_set,
                            activations,
                            batch_size=batch_size,
                            verbose=verbose,
                            gpu=gpu,
                            nw=0)

        # Create concept entry
        entry = {
            'train_set': train_idx,
            'val_set': val_idx,
            'train_stat': train_stat,
            'val_stat': val_stat,
            'cav': np.array(cav.weight.data.detach().cpu().numpy())
        }
        result.append(entry)

        # Eventually notify progress
        if queue is not None:
            queue.put(1)

        # ``result`` is a list, so the statistics of the concept that was
        # just processed are read from its own entry.
        if verbose:
            print('== Train')
            print(entry['train_stat'])
            print('== Val')
            print(entry['val_stat'])

    return result
Пример #14
0
def record_activations(model: torch.nn.Module,
                       modules_ids: List[ModuleID],
                       dataset: torch.utils.data.Dataset,
                       batch_size: int = 128,
                       cache: str = None,
                       gpu: bool = False,
                       silent: bool = False) -> Dict[ModuleID, np.ndarray]:
    """
    Record the outputs of selected modules over a whole dataset.

    The model is run in evaluation mode, without gradients, on every
    batch of ``dataset`` while forward hooks capture the outputs of the
    requested modules. The forward hooks, the training mode of the model
    and ``dataset.skip_masks`` are restored before returning, also when
    an exception interrupts the recording.

    Parameters
    ----------
    model: torch.nn.Module
        PyTorch model to analyze
    modules_ids: List[ModuleID]
        Coordinates of the modules to analyze. A bare module name is
        treated as ``(name, 0)``; the second element selects which
        hooked output of that module within one forward pass is kept
    dataset: torch.utils.data.Dataset
        Dataset containing the inputs
    batch_size: int, optional
        Batch size for the forward pass
    cache: str, optional
        Path of the folder in which to
        eventually store the activations.
        If the files of every requested module
        already exist there, they are returned
        without running the model
    gpu: bool, optional
        Flag to handle GPU usage
    silent: bool, optional
        Disables the progress bar

    Returns
    -------
    activations: Dict[ModuleID, np.ndarray]
        Dictionary mapping the module
        id to either a NumPy array
        or a memmap containing the
        activations per input and
        per unit
    """

    activations = {}
    act_size = {}

    # normalize module ids
    modules_ids = [(m, 0) if isinstance(m, str) else m for m in modules_ids]

    # module ids to string
    modules_str = [moduleid_to_string(m) for m in modules_ids]

    # eventually load from file
    if cache:
        # skip network forward pass
        skip = True

        # shape filenames
        shape_filenames = {
            m_id: os.path.join(cache, "size_%s.npy" % m_str)
            for m_id, m_str in zip(modules_ids, modules_str)
        }

        # activations filenames
        act_filenames = {
            m_id: os.path.join(cache, "act_%s.mmap" % m_str)
            for m_id, m_str in zip(modules_ids, modules_str)
        }

        # load from file
        for m_id in modules_ids:
            s_fn = shape_filenames[m_id]
            a_fn = act_filenames[m_id]
            if os.path.exists(s_fn) and os.path.exists(a_fn):
                act_size[m_id] = np.load(s_fn)
                activations[m_id] = np.memmap(a_fn,
                                              dtype=float,
                                              mode='r',
                                              shape=tuple(act_size[m_id]))
            else:
                skip = False

        # All the activations are on disk
        if skip:
            return activations

    # remember every piece of state that is changed below, so that it can be
    # restored no matter how the recording ends
    was_skipping_masks = dataset.skip_masks
    was_training = model.training
    hooks = []

    # disable concept masks retrieval
    dataset.skip_masks = True

    try:
        # fix batch size for the image loader
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_masks)

        # unique names of the modules to hook
        modules_names = list({m_name for m_name, m_idx in modules_ids})

        # init recording variables
        recording = {m_name: [] for m_name, m_idx in modules_ids}

        # Define generic hook
        def hook_feature(module, input, output, module_name):
            recording[module_name].append(output.detach().cpu().numpy())

        # Register hooks one by one, so that the ones already registered are
        # removed even if a later module cannot be retrieved
        for m_name in modules_names:
            module = get_module(model, m_name)
            hooks.append(
                module.register_forward_hook(
                    partial(hook_feature, module_name=m_name)))

        model.eval()

        # batch iteration over the inputs
        first_batch = True
        for batch_idx, batch in enumerate(
                tqdm(loader, total=len(loader), disable=silent)):

            # Eventually ignore masks
            if isinstance(batch, tuple):
                batch = batch[0]

            # Delete previous recording
            for captured in recording.values():
                captured.clear()

            # Prepare input batch
            if gpu:
                batch = batch.cuda()

            # Forward pass of the input
            with torch.no_grad():
                _ = model.forward(batch)

            # initialize tensors
            if first_batch:
                for m_id in modules_ids:
                    m_name, m_idx = m_id
                    act_size[m_id] = (len(dataset),
                                      *recording[m_name][m_idx].shape[1:])
                    if cache:
                        s_fn = shape_filenames[m_id]
                        a_fn = act_filenames[m_id]
                        np.save(s_fn, act_size[m_id])
                        activations[m_id] = np.memmap(a_fn,
                                                      dtype=float,
                                                      mode='w+',
                                                      shape=act_size[m_id])
                    else:
                        activations[m_id] = np.zeros(act_size[m_id])

                # Do not repeat the initialization
                first_batch = False

            # copy activations into the rows belonging to this batch
            start_idx = batch_idx * loader.batch_size
            end_idx = min((batch_idx + 1) * loader.batch_size, len(dataset))
            for m_id in modules_ids:
                m_name, m_idx = m_id
                activations[m_id][start_idx:end_idx] = recording[m_name][m_idx]
    finally:
        # Remove hooks
        for hook in hooks:
            hook.remove()

        # revert model status
        if was_training:
            model.train()

        # revert dataset preference
        dataset.skip_masks = was_skipping_masks

    return activations
Пример #15
0
def _compute_sigma(
        activations: Union[Tuple[str, Tuple], np.ndarray],
        dataset: torch.utils.data.Dataset, directions: np.ndarray,
        concepts: List[Concept], thresholds: np.ndarray,
        queue: Queue, batch_size: int, start: int,
        end: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Accumulate overlap statistics between activation masks and concept masks.

    For every image in ``dataset[start:end]`` the activations are projected
    on each direction and thresholded into a binary mask; these masks are
    compared with the masks of the concepts present in the image.

    Parameters
    ----------
    activations: Union[Tuple[str, Tuple], np.ndarray]
        Either the activations themselves or a ``(filename, shape)``
        tuple from which a read-only float memmap is reopened
    dataset: torch.utils.data.Dataset
        Dataset providing the annotated images; its ``skip_image``
        attribute is set to ``True``
    directions: np.ndarray
        One direction per row. ``None`` selects the canonical basis
        of size ``activations.shape[1]``
    concepts: List[Concept]
        Concepts to align
    thresholds: np.ndarray
        One activation threshold per direction
    queue: Queue
        If not ``None``, ``1`` is put on it after every batch and
        ``None`` once all batches are done; otherwise a progress bar
        is shown
    batch_size: int
        Batch size of the data loader
    start: int
        First dataset index to process
    end: int
        Dataset index at which to stop (exclusive)

    Returns
    -------
    intersection: np.ndarray
        Per direction and concept, number of positions where both the
        activation mask and the concept mask are set
    union: np.ndarray
        ``act_sum[:, None] + cmask_sum[None, :] - intersection``
    act_sum: np.ndarray
        Per direction, total size of the activation masks
    cmask_sum: np.ndarray
        Per concept, total size of the concept masks

    Raises
    ------
    ValueError
        If the number of directions and thresholds differ
    """

    # Eventually reload the memmap
    if isinstance(activations, tuple):
        activations = np.memmap(activations[0],
                                dtype=float,
                                mode='r',
                                shape=activations[1])

    # Without custom directions adopt canonical basis
    # (an identity test is required: the truth value of a multi-element
    # array is ambiguous and raises ValueError)
    # TODO: keep directions = None and exploit this later
    if directions is None:
        directions = np.eye(activations.shape[1])

    # Number of directions and concepts to align
    n_directions = directions.shape[0]
    n_concepts = len(concepts)

    # Check number of directions and thresholds
    if n_directions != thresholds.shape[0]:
        raise ValueError('Number of directions and thresholds do not match')

    # Eventually allocate arrays
    intersection = np.zeros((n_directions, n_concepts))
    act_sum = np.zeros((n_directions))
    cmask_sum = np.zeros((n_concepts))

    # Ignore images
    dataset.skip_image = True

    # Create subset
    dataset = torch.utils.data.Subset(dataset, range(start, end))

    # Init loader
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         collate_fn=collate_masks)

    # Show progress bar when not multiprocessing
    if queue is None:
        loader = tqdm(loader)

    first = True
    for batch in loader:

        # Ignore images
        if isinstance(batch, tuple):
            batch = batch[1]

        # Allocate the reusable mask buffers from the shape of the first item
        if first:
            a_mask = np.full((n_directions, *batch[0].shape), False)
            c_mask = np.full(batch[0].shape, False)
            first = False

        for image in batch:

            # Precompute directional activations
            dir_act = [(activations[image.index].T @ directions[i]).T
                       for i in range(n_directions)]

            # Only keep valid directions
            # NOTE: is any faster than max?
            valid_dirs = [
                d_idx for d_idx in range(n_directions)
                if np.any(dir_act[d_idx] > thresholds[d_idx])
            ]

            # Generate activation masks
            for d_idx in valid_dirs:
                # Retrieve directional activations
                tmp_a_mask = dir_act[d_idx]

                # Resize if convolutional
                # NOTE(review): the resized value is a PIL image and
                # ``resize`` expects (width, height) - confirm this matches
                # ``image.shape`` for non-square inputs.
                if len(tmp_a_mask.shape):
                    tmp_a_mask = Image.fromarray(tmp_a_mask) \
                        .resize(image.shape,
                                resample=Image.BILINEAR)
                # Create mask
                a_mask[d_idx] = tmp_a_mask > thresholds[d_idx]

                # Update \sum_x |M_u(x)|
                act_sum[d_idx] += np.count_nonzero(a_mask[d_idx])

            # Retrieve concepts in the image
            selected_concepts = image.select_concepts(concepts)

            for c_idx, concept in enumerate(concepts):

                if concept in selected_concepts:
                    # retrieve L_c(x)
                    c_mask = image.get_concept_mask(concept, c_mask)

                    # update \sum_x |L_c(x)|
                    cmask_sum[c_idx] += np.count_nonzero(c_mask)

                    # Update counters
                    for d_idx in valid_dirs:

                        # |M_u(x) && L_c(x)|
                        intersection[d_idx, c_idx] += np.count_nonzero(
                            np.logical_and(a_mask[d_idx], c_mask))

        # Notify end of batch
        if queue is not None:
            queue.put(1)

    # Signal that this worker has processed all of its batches
    if queue is not None:
        queue.put(None)

    # |M_u(x) || L_c(x)|
    union = act_sum[:, None] + cmask_sum[None, :] - intersection

    return intersection, union, act_sum, cmask_sum
def greedy_search(encoder: EncoderBILSTM, decoder: DecoderLSTM,
                  dataset: torch.utils.data.Dataset, use_cuda: bool,
                  batch_size: int) -> (list, list, list):
    """
    Generate questions from answers with greedy decoding.

    Batches of ``dataset`` are drawn in shuffled order (at most 201 of them).
    The answers are encoded, then the decoder is unrolled for 30 steps,
    feeding back the arg-max token of every step. Predicted token ids are
    mapped to words with the question vocabulary and cut after the first
    ``END_TOKEN``. The first 1000 (sentence, ground truth, prediction)
    triples are printed.

    :param encoder: Encoder, called as ``encoder(answers, lengths)``.
    :param decoder: Decoder, called once per generated token.
    :param dataset: Dataset exposing ``get_question_idx_to_word`` and
        ``get_answer_idx_to_word``; batches are built with ``collate_fn``.
    :param use_cuda: Move the tensors to the GPU.
    :param batch_size: Batch size of the data loader. The last batch may be
        smaller; the size of each batch is taken from the batch itself.
    :return: ``([given_sentence], ground_truth, final_prediction)`` - the
        answer words (wrapped in a one-element list), the reference question
        words and the predicted question words, one entry per sample.
    """
    max_len = 30        # number of decoding steps
    max_batches = 200   # batches with a larger index are not processed
    max_printed = 1000  # number of examples echoed at the end

    q_idx_to_word = dataset.get_question_idx_to_word()
    a_idx_to_word = dataset.get_answer_idx_to_word()
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             pin_memory=True)
    given_sentence = []
    ground_truth = []
    final_prediction = []
    encoder.eval()
    decoder.eval()

    # Decoding only: gradients are never needed.
    with torch.no_grad():
        for cnt, batch in enumerate(data_loader):
            print(cnt)
            if cnt > max_batches:
                break
            questions, questions_org_len, answers, answers_org_len, pID = batch

            # Size of *this* batch: the final batch of the loader can hold
            # fewer than ``batch_size`` samples.
            current_batch_size = len(answers)

            if use_cuda:
                questions = questions.cuda()
                answers = answers.cuda()
                answers_org_len = torch.FloatTensor(
                    np.asarray(answers_org_len))

            encoder_input, encoder_len = answers, np.asarray(answers_org_len)

            # input to the first time step of decoder is <SOS> token.
            encoder_len = torch.LongTensor(encoder_len)
            decoder_inp = torch.ones((current_batch_size, 1),
                                     dtype=torch.long)
            if use_cuda:
                encoder_len = encoder_len.cuda()
                decoder_inp = decoder_inp.cuda()

            encoder_out, encoder_hidden = encoder(encoder_input, encoder_len)
            decoder_hidden = encoder_hidden

            # The first decoder call runs with eval_mode=False, all later
            # ones with eval_mode=True.
            eval_mode = False
            predicted_sequences = []
            for _ in range(max_len):
                decoder_out, decoder_hidden, attn_scores = decoder(
                    decoder_inp,
                    decoder_hidden,
                    encoder_out,
                    answers_org_len,
                    eval_mode=eval_mode)

                # Normalize over the vocabulary axis and pick the best token.
                decoder_out = decoder_out.view(current_batch_size, -1)
                decoder_out = torch.nn.functional.log_softmax(decoder_out,
                                                              dim=1)
                prediction = torch.argmax(decoder_out, 1).unsqueeze(1)
                predicted_sequences.append(prediction)
                decoder_inp = prediction.clone()
                eval_mode = True

            # Index 0 is skipped when mapping ids back to words.
            given_sentence.extend([[
                a_idx_to_word[str(answers[i][j].item())]
                for j in range(len(answers[i])) if answers[i][j] != 0
            ] for i in range(len(answers))])
            ground_truth.extend([[
                q_idx_to_word[str(questions[i][j].item())]
                for j in range(len(questions[i])) if questions[i][j] != 0
            ] for i in range(len(questions))])

            # Turn the per-step predictions into one word list per sample,
            # stopping right after the first END_TOKEN.
            prediction = []
            for i in range(current_batch_size):
                words = []
                for step in predicted_sequences:
                    word = q_idx_to_word[str(step[i][0].item())]
                    words.append(word)
                    if word == END_TOKEN:
                        break
                prediction.append(words)
            final_prediction.extend(prediction)

    triples = zip(given_sentence, ground_truth, final_prediction)
    for shown, (sent, gt, pred) in enumerate(triples):
        if shown >= max_printed:
            break
        print("Sentence: %s \nGT Q: %s \nPred Q: %s" % (sent, gt, pred))
    return [given_sentence], ground_truth, final_prediction