def __run_epoch(self,
                    epoch_name: str,
                    data: Iterable[Any],
                    data_fold: DataFold,
                    quiet: Optional[bool] = False,
                    summary_writer: Optional[tf.summary.FileWriter] = None) \
            -> Tuple[float, List[Dict[str, Any]], int, float, Any]:
        """
        Run one epoch of the model over the given data.
        :param epoch_name: Name of the epoch, used only for progress output.
        :param data: The data to iterate over in minibatches.
        :param data_fold: The fold being processed; dropout and the train step are only applied for DataFold.TRAIN.
        :param quiet: If True, suppress the per-batch progress output.
        :param summary_writer: Optional TensorFlow summary writer that receives per-batch summaries.
        :return: Tuple of (per-graph loss, per-batch task metric results, number of processed graphs,
            epoch wall-clock time, final representation of the last batch).
        """
        batch_iterator = self.task.make_minibatch_iterator(
            data, data_fold, self.__placeholders, self.params['max_nodes_in_batch'])
        batch_iterator = ThreadedIterator(batch_iterator, max_queue_size=5)
        task_metric_results = []
        start_time = time.time()
        processed_graphs, processed_nodes, processed_edges = 0, 0, 0
        epoch_loss = 0.0
        final_representation = 0
        for step, batch_data in enumerate(batch_iterator):
            # assert step == 0 or not DataFold.TEST, step
            if data_fold == DataFold.TRAIN:
                batch_data.feed_dict[self.__placeholders['graph_layer_input_dropout_keep_prob']] = \
                    self.params['graph_layer_input_dropout_keep_prob']
            batch_data.feed_dict[self.__placeholders['num_graphs']] = batch_data.num_graphs

            # Collect some statistics:
            processed_graphs += batch_data.num_graphs
            processed_nodes += batch_data.num_nodes
            processed_edges += batch_data.num_edges

            fetch_dict = {'task_metrics': self.__ops['task_metrics'],
                          'final_representation': self.__ops['final_representation']}
            if summary_writer:
                fetch_dict['tf_summaries'] = self.__ops['tf_summaries']
                fetch_dict['total_num_graphs'] = self.__ops['total_num_graphs']
            if data_fold == DataFold.TRAIN:
                fetch_dict['train_step'] = self.__ops['train_step']
            fetch_results = self.sess.run(fetch_dict, feed_dict=batch_data.feed_dict)
            epoch_loss += fetch_results['task_metrics']['loss'] * batch_data.num_graphs
            task_metric_results.append(fetch_results['task_metrics'])
            final_representation = fetch_results['final_representation']
            if not quiet:
                print("Running %s, batch %i (has %i graphs). Loss so far: %.4f"
                      % (epoch_name, step, batch_data.num_graphs, epoch_loss / processed_graphs),
                      end='\r')
            if summary_writer:
                summary_writer.add_summary(fetch_results['tf_summaries'], fetch_results['total_num_graphs'])

        assert processed_graphs > 0, "Can't run epoch over empty dataset."

        epoch_time = time.time() - start_time
        per_graph_loss = epoch_loss / processed_graphs

        return per_graph_loss, task_metric_results, processed_graphs, epoch_time, final_representation
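Every example on this page funnels its data through ThreadedIterator, which in these codebases typically comes from the dpu_utils package: it runs the wrapped iterator in a background thread and hands elements over via a bounded queue, so data preparation overlaps with model computation. Below is a minimal sketch of such a wrapper under my own names; the real implementation also forwards exceptions raised in the worker thread, which this sketch omits.

import queue
import threading
from typing import Iterator, TypeVar

T = TypeVar('T')


class SimpleThreadedIterator(Iterator[T]):
    """Sketch: run `original_iterator` in a background thread, buffering
    its elements in a bounded queue (simplified; no exception forwarding)."""

    _SENTINEL = object()  # marks exhaustion of the wrapped iterator

    def __init__(self, original_iterator: Iterator[T],
                 max_queue_size: int = 2, enabled: bool = True):
        self.__enabled = enabled
        if enabled:
            self.__queue = queue.Queue(maxsize=max_queue_size)
            worker = threading.Thread(target=self.__produce,
                                      args=(original_iterator,), daemon=True)
            worker.start()
        else:
            self.__iterator = iter(original_iterator)  # degrade to plain iteration

    def __produce(self, original_iterator: Iterator[T]) -> None:
        for element in original_iterator:
            self.__queue.put(element)  # blocks while the queue is full
        self.__queue.put(self._SENTINEL)

    def __next__(self) -> T:
        if not self.__enabled:
            return next(self.__iterator)
        element = self.__queue.get()
        if element is self._SENTINEL:
            raise StopIteration
        return element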
Example #2
 def compute_metadata(self,
                      dataset_iterator: Iterator[TRawDatapoint],
                      parallelize: bool = True) -> None:
     """
     Compute the metadata for this model including its children.
     This function should be invoked by the root-level model.
     """
     assert not self.__metadata_initialized, "Metadata has already been initialized."
     self.__initialize_metadata_recursive()
     for element in ThreadedIterator(dataset_iterator, enabled=parallelize):
         self.update_metadata_from(element)
     self.__finalize_metadata_recursive()
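This is a two-pass pattern: one streaming pass over the raw data accumulates per-module metadata (vocabularies, value ranges, and similar), bracketed by recursive initialize/finalize calls over the module tree. As a hedged illustration of what a child module might implement, here is a hypothetical vocabulary module; the method names mirror the recursive calls above but are assumptions, not the library's actual API.

from collections import Counter
from typing import Any, Dict, Optional


class VocabularyModuleSketch:
    """Hypothetical child component illustrating the metadata life cycle."""

    def __init__(self, max_vocab_size: int = 10000):
        self.__max_vocab_size = max_vocab_size
        self.__token_counter: Optional[Counter] = None
        self.vocabulary: Optional[Dict[str, int]] = None

    def initialize_metadata(self) -> None:
        self.__token_counter = Counter()  # start accumulating

    def update_metadata_from(self, datapoint: Any) -> None:
        # Assumes each raw datapoint exposes a 'tokens' field; one cheap pass per sample.
        self.__token_counter.update(datapoint['tokens'])

    def finalize_metadata(self) -> None:
        # Freeze the accumulator into the artifact the model will use at tensorization time.
        most_common = self.__token_counter.most_common(self.__max_vocab_size)
        self.vocabulary = {token: idx for idx, (token, _) in enumerate(most_common)}
        self.__token_counter = None  # the counter is no longer needed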
Example #3
    def minibatch_iterator(
        self,
        tensorized_data: Iterator[Tuple[TTensorizedDatapoint,
                                        Optional[TRawDatapoint]]],
        device: Union[str, torch.device],
        max_minibatch_size: int,
        yield_partial_minibatches: bool = True,
        shuffle_input: bool = False,
        parallelize: bool = True,
    ) -> Iterator[Tuple[Dict[str, Any], List[Optional[TRawDatapoint]]]]:
        """
        An iterator that yields minibatches to be consumed by a neural module.
        :param tensorized_data: An iterator of tensorized data. Commonly that's the output of tensorize_dataset()
        :param device: The device on which the tensorized data will be stored.
        :param max_minibatch_size: the maximum size of the minibatch.
        :param yield_partial_minibatches: If True, yield partial minibatches, i.e. minibatches that have not
            reached `max_minibatch_size` and that `extend_minibatch_with` did not consider full.
            Users might want to set this to False when training.
        :param shuffle_input: Should the `tensorized_data` be shuffled? (e.g. during training)
        :param parallelize: If True, minibatching will be parallelized. This may make debugging harder.
        :return: An iterator that yields tuples with the minibatch data and the raw data points (if they are present in `tensorized_data`).
        """
        assert self.__metadata_initialized, "Metadata has not been initialized."

        if shuffle_input:
            tensorized_data = shuffled_iterator(tensorized_data)

        unfinalized_minibatches = ThreadedIterator(
            self.__iterate_unfinalized_minibatches(tensorized_data,
                                                   max_minibatch_size,
                                                   yield_partial_minibatches),
            enabled=parallelize,
        )
        yield from ThreadedIterator(
            ((self.finalize_minibatch(d[0], device), d[1])
             for d in unfinalized_minibatches),
            enabled=parallelize,
        )
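For context, here is a hedged sketch of how a caller might drive this iterator. The names `model`, `raw_datapoints`, and `tensorize_dataset` are assumptions standing in for the surrounding codebase; the minibatch dict being unpacked as keyword arguments matches the train() example further down.

import torch

# Hypothetical driver loop (not the library's actual entry point).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tensorized_data = model.tensorize_dataset(raw_datapoints)  # assumed producer of (tensorized, raw) pairs

for mb_data, raw_points in model.minibatch_iterator(
        tensorized_data,
        device=device,
        max_minibatch_size=64,
        yield_partial_minibatches=False,  # a common choice during training
        shuffle_input=True,
        parallelize=True):
    loss = model(**mb_data)  # consume the minibatch dict as keyword arguments
    loss.backward()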
Example #4
    def __run_epoch(self,
                    epoch_name: str,
                    data: Iterable[Any],
                    data_fold: DataFold,
                    quiet: Optional[bool] = False,
                    summary_writer: Optional[tf.summary.FileWriter] = None) \
            -> Tuple[float, List[Dict[str, Any]], int, float, float, float]:
        # Collect the language's reserved words; these must not be renamed.
        unsplittable_keywords = get_language_keywords('csharp')
        ADVERSARIAL_DEPTH = 3
        START_ADVERSARY_ALPHABET = 2
        END_ADVERSARY_ALPHABET = 28
        TARGETED_ATTACK = True
        SELECTED_CANDIDATE_ID_TARGETED_ATTACK = 1  # 0 is the correct one
        logfile = open(
            "example_log_{}.txt".format(
                datetime.now().strftime("%d-%m-%Y_%H-%M-%S")), "w")

        def adverse_var(unique_label_to_adverse,
                        unique_label_to_adverse_grads):
            # adverse label
            # OPTION: replace with random char
            # unique_label_to_adverse = adversarial.adversary_by_prefix_random(unique_label_to_adverse, -1)
            # OPTION: replace constant amount of chars with argmax
            # unique_label_to_adverse = adversarial.adversary_by_prefix_rename(unique_label_to_adverse,
            #                                                                  unique_label_to_adverse_grads, -1)
            # OPTION: replace argmax id with argmax char
            # unique_label_to_adverse = adversarial.adversary_all_or_until_argmax_id(unique_label_to_adverse,
            # unique_label_to_adverse_grads)
            # OPTION: replace all19 with argmax char
            # unique_label_to_adverse = adversarial.adversary_all19_by_argmax(unique_label_to_adverse,
            # unique_label_to_adverse_grads)
            # OPTION: replace adversary_all_or_until_top_and_index i1c1
            # unique_label_to_adverse = adversarial.adversary_all_or_until_top_and_index(unique_label_to_adverse,
            # unique_label_to_adverse_grads,
            # index_place=1, char_place=1)
            return adversarial.adversary_all19_by_argmax(
                unique_label_to_adverse, unique_label_to_adverse_grads)

        # Build the minibatch iterator for the evaluation loop.
        batch_iterator = self.task.make_minibatch_iterator(
            data, data_fold, self.__placeholders,
            self.params['max_nodes_in_batch'])

        batch_iterator = ThreadedIterator(batch_iterator, max_queue_size=5)
        task_metric_results = []
        start_time = time.time()
        processed_graphs, processed_nodes, processed_edges = 0, 0, 0
        epoch_loss = 0.0
        correct_predictions = 0
        adversarial_predictions = 0

        for step, batch_data in enumerate(batch_iterator):
            if data_fold == DataFold.TRAIN:
                batch_data.feed_dict[self.__placeholders['graph_layer_input_dropout_keep_prob']] = \
                    self.params['graph_layer_input_dropout_keep_prob']
            batch_data.feed_dict[
                self.__placeholders['num_graphs']] = batch_data.num_graphs
            # Collect some statistics:
            processed_graphs += batch_data.num_graphs
            processed_nodes += batch_data.num_nodes
            processed_edges += batch_data.num_edges

            fetch_dict = {'task_metrics': self.__ops['task_metrics']}
            if summary_writer:
                fetch_dict['tf_summaries'] = self.__ops['tf_summaries']
                fetch_dict['total_num_graphs'] = self.__ops['total_num_graphs']
            if data_fold == DataFold.TRAIN:
                fetch_dict['train_step'] = self.__ops['train_step']
            fetch_results = self.sess.run(fetch_dict,
                                          feed_dict=batch_data.feed_dict)
            epoch_loss += fetch_results['task_metrics'][
                'loss'] * batch_data.num_graphs
            task_metric_results.append(fetch_results['task_metrics'])

            if not quiet:
                print(
                    "Running %s, batch %i (has %i graphs). Loss so far: %.4f. adversarial so far: %i/%i (%.4f)"
                    %
                    (epoch_name, step, batch_data.num_graphs,
                     epoch_loss / processed_graphs, adversarial_predictions,
                     correct_predictions, adversarial_predictions /
                     correct_predictions if correct_predictions > 0 else 0.0),
                    end='\r')
            if summary_writer:
                summary_writer.add_summary(fetch_results['tf_summaries'],
                                           fetch_results['total_num_graphs'])

            # Adversarial process: try to flip the model's prediction by renaming variables.
            correct = fetch_results["task_metrics"][
                "num_correct_predictions"] == 1

            # fetch relevant data from batch
            unique_labels_as_characters = \
                batch_data.feed_dict[self.__placeholders['unique_labels_as_characters']]
            node_labels_to_unique_labels = \
                batch_data.feed_dict[self.__placeholders['node_labels_to_unique_labels']]
            candidate_node_ids = batch_data.feed_dict[
                self.__placeholders['candidate_node_ids']]
            candidate_node_ids_mask = batch_data.feed_dict[
                self.__placeholders['candidate_node_ids_mask']]
            node_labels = batch_data.debug_data["node_labels"][0]

            # preprocess for adversarial
            masked_candidate_node_ids = candidate_node_ids[0][
                candidate_node_ids_mask[0]]
            candidate_node_varnames = [
                node_labels[str(i)] for i in masked_candidate_node_ids
            ]
            variable_names_nodes = [
                int(id) for id, name in node_labels.items()
                if name not in unsplittable_keywords
                and name not in candidate_node_varnames and len(name) > 0
                and name[0].islower()
            ]
            variable_names_unique_labels_ids = np.unique(
                node_labels_to_unique_labels[variable_names_nodes])

            if correct and variable_names_unique_labels_ids.size > 0:
                correct_predictions += 1
            else:
                continue

            # adversarial steps
            TARGET_CANDIDATE_ID_TO_ADVERSE = 0  # 0 is the correct one

            if TARGETED_ATTACK:
                # replace between true label and adversarial label
                true_target_node = candidate_node_ids[0][0]
                candidate_node_ids[0][0] = candidate_node_ids[0][
                    SELECTED_CANDIDATE_ID_TARGETED_ATTACK]
                candidate_node_ids[0][
                    SELECTED_CANDIDATE_ID_TARGETED_ATTACK] = true_target_node

            # node_to_adverse_id = candidate_node_ids[0][TARGET_CANDIDATE_ID_TO_ADVERSE]
            # unique_label_to_adverse_id = node_labels_to_unique_labels[node_to_adverse_id]

            #### Simple attack: compute gradients once and change only a single variable name
            simple_attack_works = False
            # grads computation
            grads = self.sess.run(self.task.unique_labels_input_grads,
                                  feed_dict=batch_data.feed_dict)
            grads = grads[0] if not TARGETED_ATTACK else -grads[0]
            for node_label_id in variable_names_unique_labels_ids:

                # unique_label_to_adverse_id = variable_names_unique_labels_ids[0]
                unique_label_to_adverse_id = node_label_id
                unique_label_to_adverse = unique_labels_as_characters[
                    unique_label_to_adverse_id]

                # back up the original label before mutating it
                old_label_ints = unique_label_to_adverse.copy()
                old_label = adversarial.construct_name_from_ints(
                    old_label_ints, self.task.index_to_alphabet)

                unique_label_to_adverse_grads = grads[
                    unique_label_to_adverse_id, :,
                    START_ADVERSARY_ALPHABET:END_ADVERSARY_ALPHABET]

                unique_label_to_adverse = adverse_var(
                    unique_label_to_adverse, unique_label_to_adverse_grads)

                fetch_results = self.sess.run(fetch_dict,
                                              feed_dict=batch_data.feed_dict)

                if (not TARGETED_ATTACK and fetch_results["task_metrics"]["num_correct_predictions"] == 0)\
                        or (TARGETED_ATTACK and fetch_results["task_metrics"]["num_correct_predictions"] == 1):
                    new_label = adversarial.construct_name_from_ints(
                        unique_label_to_adverse, self.task.index_to_alphabet)
                    adversarial_predictions += 1
                    simple_attack_works = True
                    logfile.write("filename: {}\n".format(
                        batch_data.debug_data["filename"][0]))
                    logfile.write("slot_token_idx: {}\n".format(
                        batch_data.debug_data["slot_token_idx"][0]))
                    logfile.write(
                        "candidates: {}\n".format(candidate_node_varnames))
                    logfile.write("mutation: {} -> {}\n".format(
                        old_label, new_label))
                    break

                # the attack failed for this variable; restore the original label
                np.copyto(unique_label_to_adverse, old_label_ints)

            if simple_attack_works:
                continue

            #### Complex attack: recompute gradients and mutate variable names one by one
            complex_attack_works = False
            # make backup unique_labels_as_characters
            unique_labels_as_characters_backup = unique_labels_as_characters.copy()

            # compute grads & variables
            adversarial_vars = []
            for _ in range(ADVERSARIAL_DEPTH):
                grads = self.sess.run(self.task.unique_labels_input_grads,
                                      feed_dict=batch_data.feed_dict)
                grads = grads[0] if not TARGETED_ATTACK else -grads[0]
                unique_label_to_adverse_grads = grads[:, :,
                                                      START_ADVERSARY_ALPHABET:
                                                      END_ADVERSARY_ALPHABET]
                adversarial_vars.append(
                    adversarial.adversary_all19_by_argmax_batch(
                        unique_label_to_adverse_grads))
                for adverse_index in variable_names_unique_labels_ids:
                    np.copyto(unique_labels_as_characters[adverse_index],
                              adversarial_vars[-1][adverse_index])

            # restore unique_labels_as_characters
            np.copyto(unique_labels_as_characters,
                      unique_labels_as_characters_backup)

            var_rename_dict = {}
            for node_label_id in variable_names_unique_labels_ids:

                # unique_label_to_adverse_id = variable_names_unique_labels_ids[0]
                unique_label_to_adverse_id = node_label_id
                unique_label_to_adverse = unique_labels_as_characters[
                    unique_label_to_adverse_id]

                # back up the original label before mutating it
                old_label_ints = unique_label_to_adverse.copy()
                old_label = adversarial.construct_name_from_ints(
                    old_label_ints, self.task.index_to_alphabet)

                # try adversarials
                for adversarial_var in adversarial_vars:
                    np.copyto(unique_label_to_adverse,
                              adversarial_var[unique_label_to_adverse_id])

                    fetch_results = self.sess.run(
                        fetch_dict, feed_dict=batch_data.feed_dict)
                    if (not TARGETED_ATTACK and fetch_results["task_metrics"]["num_correct_predictions"] == 0)\
                            or (TARGETED_ATTACK and fetch_results["task_metrics"]["num_correct_predictions"] == 1):
                        new_label = adversarial.construct_name_from_ints(
                            unique_label_to_adverse,
                            self.task.index_to_alphabet)
                        var_rename_dict[old_label] = new_label
                        adversarial_predictions += 1
                        logfile.write("filename: {}\n".format(
                            batch_data.debug_data["filename"][0]))
                        logfile.write("slot_token_idx: {}\n".format(
                            batch_data.debug_data["slot_token_idx"][0]))
                        logfile.write(
                            "candidates: {}\n".format(candidate_node_varnames))
                        logfile.write(
                            "mutation: {} \n".format(var_rename_dict))
                        complex_attack_works = True
                        break

                if complex_attack_works:
                    break
                else:  # if the attack failed, record the mutated name and continue to the next variable
                    new_label = adversarial.construct_name_from_ints(
                        unique_label_to_adverse, self.task.index_to_alphabet)
                    var_rename_dict[old_label] = new_label

        assert processed_graphs > 0, "Can't run epoch over empty dataset."
        logfile.close()

        epoch_time = time.time() - start_time
        per_graph_loss = epoch_loss / processed_graphs
        graphs_per_sec = processed_graphs / epoch_time
        nodes_per_sec = processed_nodes / epoch_time
        edges_per_sec = processed_edges / epoch_time

        print("correct:{} adversarial:{} ({})".format(
            correct_predictions, adversarial_predictions,
            adversarial_predictions / correct_predictions))
        return per_graph_loss, task_metric_results, processed_graphs, graphs_per_sec, nodes_per_sec, edges_per_sec
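The core adversarial move in Example #4 is gradient-guided renaming: take the gradient of the loss with respect to the character-level encoding of a variable name and substitute each character with the alphabet entry whose gradient increases the loss the most (for a targeted attack the gradients are negated first, as in the `grads = grads[0] if not TARGETED_ATTACK else -grads[0]` lines above). The helpers in the `adversarial` module are project-specific, so what follows is only a hedged numpy sketch of the argmax substitution; the zero-padding convention and the alphabet offset of 2 (mirroring START_ADVERSARY_ALPHABET) are assumptions.

import numpy as np


def argmax_char_substitution(label_char_ids: np.ndarray,
                             label_grads: np.ndarray,
                             alphabet_offset: int = 2) -> np.ndarray:
    """Hedged sketch of what adversarial.adversary_all19_by_argmax appears to do.

    label_char_ids: (max_chars,) int ids of one name's characters (0 assumed to be padding).
    label_grads:    (max_chars, restricted_alphabet_size) loss gradients w.r.t.
                    each character slot, pre-sliced to the attackable alphabet.
    """
    # For each character slot, pick the restricted-alphabet entry that pushes
    # the loss up the most, then shift back to full-alphabet ids.
    best_chars = label_grads.argmax(axis=-1) + alphabet_offset
    # Keep padding positions untouched so the name keeps its original length.
    return np.where(label_char_ids > 0, best_chars, label_char_ids)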
Example #5
 def validation_tensors():
     yield from ThreadedIterator(
         original_iterator=data_to_tensor_iterator(validation_data),
         max_queue_size=10 * self.__minibatch_size)
Example #6
 def training_tensors():
     yield from ThreadedIterator(
         original_iterator=data_to_tensor_iterator(training_data),
         max_queue_size=10 * self.__minibatch_size)
Example #7
    def train(self,
              training_data: Iterable[InputData],
              validation_data: Iterable[InputData],
              show_progress_bar: bool = True,
              patience: int = 5,
              initialize_metadata: bool = True,
              exponential_running_average_factor: float = 0.97,
              get_parameters_to_freeze: Optional[Callable[[], Set]] = None,
              parallel_minibatch_creation: bool = False,
              device: Optional[Union[str, torch.device]] = None) -> None:
        """
        The training-validation loop for `BaseComponent`s.

        :param training_data: An iterable that yields the full training data on each iteration.
        :param validation_data: An iterable that yields the full validation data on each iteration.
        :param show_progress_bar: If True, show a progress bar.
        :param patience: The number of iterations before early stopping kicks in.
        :param initialize_metadata: If true, initialize the metadata from the training_data. Otherwise,
            assume that the model that is being trained has its metadata already initialized.
        :param exponential_running_average_factor: The factor of the running average of the training loss
            displayed in the progress bar.
        :param get_parameters_to_freeze: The (optional) callable that returns the set of parameters to freeze during training.
        :param parallel_minibatch_creation: If True, the minibatches will be created in a separate thread.
        :param device: The device to train on. Defaults to the first CUDA device if available, otherwise the CPU.
        """
        if initialize_metadata:
            self.__load_metadata(training_data)

        self.LOGGER.info('Model has %s parameters',
                         self.__model.num_parameters())
        self.LOGGER.debug('Data Tensorization Started...')

        def data_to_tensor_iterator(data):
            for datapoint in data:
                tensorized_datapoint = self.__model.load_data_from_sample(
                    datapoint)
                if tensorized_datapoint is not None:
                    yield tensorized_datapoint

        def training_tensors():
            yield from ThreadedIterator(
                original_iterator=data_to_tensor_iterator(training_data),
                max_queue_size=10 * self.__minibatch_size)

        def validation_tensors():
            yield from ThreadedIterator(
                original_iterator=data_to_tensor_iterator(validation_data),
                max_queue_size=10 * self.__minibatch_size)

        def minibatch_iterator(
                data_iterator: Iterator[TensorizedData],
                return_partial_minibatches: bool = False) -> Iterator[Tuple[Dict, int]]:
            while True:
                mb_data, batch_is_full, num_elements = self.__model.create_minibatch(
                    data_iterator, max_num_items=self.__minibatch_size)
                if num_elements == 0:
                    break
                elif not batch_is_full and not return_partial_minibatches:
                    break  # Do not return partial minibatches when the iterator is exhausted.
                else:
                    yield mb_data, num_elements

        if device is None:
            device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.__model.to(device)
        self.LOGGER.info('Using %s for training.', device)

        if get_parameters_to_freeze is None:
            get_parameters_to_freeze = lambda: set()
        trainable_parameters = set(
            self.__model.parameters()) - get_parameters_to_freeze()
        optimizer = self.__create_optimizer(trainable_parameters)
        scheduler = None if self.__create_scheduler is None else self.__create_scheduler(
            optimizer)

        for hook in self.__training_start_hooks:
            hook(self.__model, optimizer)

        best_loss = float('inf')  # type: float
        num_epochs_not_improved = 0  # type: int
        for epoch in range(self.__max_num_epochs):
            self.__model.train()

            data_iter = shuffled_iterator(training_tensors())
            sum_epoch_loss = 0.0
            running_avg_loss = 0.0
            num_minibatches = 0
            num_samples = 0

            start_time = time.time()
            self.__model.reset_metrics()
            with tqdm(desc='Training',
                      disable=not show_progress_bar,
                      leave=False) as progress_bar:
                for step_idx, (mb_data, num_elements) in enumerate(
                        ThreadedIterator(minibatch_iterator(
                            data_iter, return_partial_minibatches=False),
                                         enabled=parallel_minibatch_creation)):
                    optimizer.zero_grad()
                    mb_loss = self.__model(**mb_data)
                    mb_loss.backward()

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step(epoch_idx=epoch, epoch_step=step_idx)

                    loss = float(mb_loss.cpu())
                    if math.isnan(loss):
                        raise Exception('Training Loss has a NaN value.')

                    sum_epoch_loss += loss
                    num_minibatches += 1
                    num_samples += num_elements

                    if num_minibatches == 1:  # First minibatch
                        running_avg_loss = loss
                    else:
                        running_avg_loss = exponential_running_average_factor * running_avg_loss + (
                            1 - exponential_running_average_factor) * loss
                    progress_bar.update()
                    progress_bar.set_postfix(Loss=f'{running_avg_loss:.2f}')

            elapsed_time = time.time() - start_time  # type: float
            self.LOGGER.info('Training complete in %.1fsec [%.2f samples/sec]',
                             elapsed_time, (num_samples / elapsed_time))
            assert num_minibatches > 0, 'No training minibatches were created. The minibatch size may be too large or the training dataset size too small.'
            self.LOGGER.info('Epoch %i: Avg Train Loss %.2f', epoch + 1,
                             sum_epoch_loss / num_minibatches)
            train_metrics = self.__model.report_metrics()
            for epoch_hook in self.__train_epoch_end_hooks:
                epoch_hook(self.__model, epoch, train_metrics)
            if len(train_metrics) > 0:
                self.LOGGER.info('Training Metrics: %s',
                                 json.dumps(train_metrics, indent=2))

            # Now do validation!
            self.__model.eval()
            data_iter = validation_tensors()
            sum_epoch_loss = 0
            num_minibatches = 0
            num_samples = 0
            start_time = time.time()
            self.__model.reset_metrics()
            with tqdm(desc='Validation',
                      disable=not show_progress_bar,
                      leave=False) as progress_bar, torch.no_grad():
                for mb_data, num_elements in ThreadedIterator(
                        minibatch_iterator(data_iter,
                                           return_partial_minibatches=True),
                        enabled=parallel_minibatch_creation):
                    mb_loss = self.__model(**mb_data)

                    loss = float(mb_loss.cpu())
                    if math.isnan(loss):
                        raise Exception('Validation Loss has a NaN value.')

                    sum_epoch_loss += loss
                    num_minibatches += 1
                    num_samples += num_elements

                    progress_bar.update()
                    progress_bar.set_postfix(
                        Loss=f'{sum_epoch_loss / num_minibatches:.2f}')

            elapsed_time = time.time() - start_time
            assert num_samples > 0, 'No validation data was found.'
            validation_loss = sum_epoch_loss / num_minibatches
            self.LOGGER.info(
                'Validation complete in %.1fsec [%.2f samples/sec]',
                elapsed_time, (num_samples / elapsed_time))
            self.LOGGER.info('Epoch %i: Avg Valid Loss %.2f', epoch + 1,
                             validation_loss)
            validation_metrics = self.__model.report_metrics()
            for epoch_hook in self.__validation_epoch_end_hooks:
                epoch_hook(self.__model, epoch, validation_metrics)
            if len(validation_metrics) > 0:
                self.LOGGER.info('Validation Metrics: %s',
                                 json.dumps(validation_metrics, indent=2))

            if validation_loss < best_loss:
                self.LOGGER.info('Best loss so far --- Saving model.')
                num_epochs_not_improved = 0
                self.__save_current_model()
                best_loss = validation_loss
            else:
                num_epochs_not_improved += 1
                if num_epochs_not_improved > patience:
                    self.LOGGER.warning(
                        'After %s epochs loss has not improved. Stopping.',
                        num_epochs_not_improved)
                    break

        # Restore the best model that was found.
        self.restore_model()
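One supporting utility worth noting: the training loop above shuffles a tensor stream with shuffled_iterator(training_tensors()) rather than materializing the dataset, which keeps the ThreadedIterator pipeline streaming. Below is a hedged sketch of a buffer-based streaming shuffle in that spirit; the function name and buffer size are my own, and the real dpu_utils helper may differ in details.

import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def buffered_shuffled_iterator(inner: Iterable[T], buffer_size: int = 5000) -> Iterator[T]:
    """Sketch: approximate shuffling of a stream using a fixed-size buffer."""
    buffer = []
    for element in inner:
        buffer.append(element)
        if len(buffer) >= buffer_size:
            # Emit a uniformly random buffered element and keep streaming.
            swap_idx = random.randrange(len(buffer))
            buffer[swap_idx], buffer[-1] = buffer[-1], buffer[swap_idx]
            yield buffer.pop()
    random.shuffle(buffer)  # drain whatever is left in random order
    yield from buffer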
Example #8
 def validation_tensors():
     yield from ThreadedIterator(
         original_iterator=(self.__model.load_data_from_sample(d)
                            for d in validation_data),
         max_queue_size=10 * self.__minibatch_size)
Example #9
    def __run_epoch(self,
                    epoch_name: str,
                    data: Iterable[Any],
                    data_fold: DataFold,
                    quiet: Optional[bool] = False,
                    # summary_writer: Optional[tf.summary.SummaryWriter] = None) \
                    summary_writer: Optional[tf.compat.v1.summary.FileWriter] = None) \
            -> Tuple[float, List[Dict[str, Any]], int, float, float, float]:
        batch_iterator = self.task.make_minibatch_iterator(
            data, data_fold, self.__placeholders,
            self.params['max_nodes_in_batch'])
        batch_iterator = ThreadedIterator(batch_iterator, max_queue_size=5)
        task_metric_results = []
        start_time = time.time()
        processed_graphs, processed_nodes, processed_edges = 0, 0, 0
        epoch_loss = 0.0
        epoch_accuracy = 0.0
        for step, batch_data in enumerate(batch_iterator):
            if data_fold == DataFold.TRAIN:
                batch_data.feed_dict[self.__placeholders['graph_layer_input_dropout_keep_prob']] = \
                    self.params['graph_layer_input_dropout_keep_prob']
            batch_data.feed_dict[
                self.__placeholders['num_graphs']] = batch_data.num_graphs
            # Collect some statistics:
            processed_graphs += batch_data.num_graphs
            processed_nodes += batch_data.num_nodes
            processed_edges += batch_data.num_edges

            fetch_dict = {'task_metrics': self.__ops['task_metrics']}
            if summary_writer:
                fetch_dict['tf_summaries'] = self.__ops['tf_summaries']
                fetch_dict['total_num_graphs'] = self.__ops['total_num_graphs']
            if data_fold == DataFold.TRAIN:
                fetch_dict['train_step'] = self.__ops['train_step']
            fetch_results = self.sess.run(fetch_dict,
                                          feed_dict=batch_data.feed_dict)
            epoch_loss += fetch_results['task_metrics'][
                'loss'] * batch_data.num_graphs
            epoch_accuracy += fetch_results['task_metrics'][
                'accuracy'] * batch_data.num_graphs
            task_metric_results.append(fetch_results['task_metrics'])

            if not quiet:
                print(
                    "Running %s, batch %i (has %i graphs). Loss so far: %.4f. Accuracy so far: %.4f."
                    % (epoch_name, step, batch_data.num_graphs,
                       epoch_loss / processed_graphs,
                       epoch_accuracy / processed_graphs * 100),
                    end='\r')
            if summary_writer:
                summary_writer.add_summary(fetch_results['tf_summaries'],
                                           fetch_results['total_num_graphs'])

        assert processed_graphs > 0, "Can't run epoch over empty dataset."

        epoch_time = time.time() - start_time
        per_graph_loss = epoch_loss / processed_graphs
        graphs_per_sec = processed_graphs / epoch_time
        nodes_per_sec = processed_nodes / epoch_time
        edges_per_sec = processed_edges / epoch_time
        return per_graph_loss, task_metric_results, processed_graphs, graphs_per_sec, nodes_per_sec, edges_per_sec