def create_tsm_solution(self, epoch: int):
    """Build the route encoded by the current weights and return its length.

    Every feature is assigned to its winning neuron (a ``(row, col)`` pair
    from ``Utilities.get_winning_neuron_2d``).  Neurons are then visited in
    ascending order of their column index, and features that share a neuron
    keep the order in which they appear in ``self.features``.

    The length is measured on the matching entries of ``self.originals`` and
    includes the closing edge from the last city back to the first.  When
    ``self.should_display`` is set, the ``self.features`` entries in route
    order are forwarded to the visualizer together with the length.

    Args:
        epoch: Epoch number forwarded to the visualizer.

    Returns:
        Total length of the closed tour, or ``0`` when there are no features.
    """
    solution_map = defaultdict(list)
    for index, feature in enumerate(self.features):
        winner = Utilities.get_winning_neuron_2d(feature, self.weights)
        solution_map[winner].append(index)

    # Order neurons by column (coords[1]) to walk along the output ring.
    solution_indices = []
    for key in sorted(solution_map, key=lambda coords: coords[1]):
        solution_indices.extend(solution_map[key])

    normalized_solution = [self.features[i] for i in solution_indices]
    solution = [self.originals[i] for i in solution_indices]

    # Pair each city with its successor, wrapping the last one back to the
    # first.  Slicing instead of indexing solution[-1] keeps an empty route
    # from raising IndexError; a single city pairs with itself (distance 0).
    successors = solution[1:] + solution[:1]
    total = 0
    for current, following in zip(solution, successors):
        total += Utilities.euclidian_distance(current, following)

    if self.should_display:
        self.tsm_visualizer.update_solution(tensor(normalized_solution),
                                            total, epoch)
    return total
    def run(self):
        """Train the self-organising map, then evaluate or report the result.

        For each epoch the neighbourhood radius and learning rate are derived
        from ``self.initial_radius`` / ``self.initial_l_rate`` scaled by
        ``self.radius_decay_func(i)`` / ``self.l_rate_decay_func(i)``.  Every
        feature vector is then applied once via ``_train_on_case``.

        In MNIST mode (``self.mnist``), on visualisation epochs the label of
        each case is recorded on its winning neuron.  The recorded memory is
        plotted when ``self.should_display`` is set, and the last recorded
        memory is passed to ``self.test`` after training.  Otherwise (TSM
        mode) the route built from the weights is visualised on display
        epochs, and the final route length is printed and returned.

        Returns:
            The final route length in TSM mode; ``None`` in MNIST mode.
        """
        n_cases_to_run = self.n_epochs * len(self.features)
        counter = 0

        memory = None
        radius = None
        l_rate = None

        print("\nStarting Training Session\n")

        for i in range(self.n_epochs):
            # Evaluated once per epoch rather than once per training case.
            visualize = Utilities.time_to_visualize(i, self.display_interval,
                                                    self.n_epochs)
            record_labels = visualize and self.mnist
            if record_labels:
                memory = [[[] for _ in range(self.n_output_cols)]
                          for _ in range(self.n_output_rows)]

            radius = int(round(self.initial_radius *
                               self.radius_decay_func(i)))
            l_rate = self.initial_l_rate * self.l_rate_decay_func(i)
            # i, l_rate and radius are fixed for the whole epoch, so the
            # progress message only needs formatting once.
            message = "Epoch: %d \t L_Rate: %.3f \t Radius: %3d" % (
                i, l_rate, radius)

            for j, case in enumerate(self.features):
                row, col = self._train_on_case(case, l_rate, radius)
                if record_labels:
                    memory[row][col].append(self.labels[j])

                counter += 1
                if j % 10 == 0:
                    Utilities.print_progress(n_cases_to_run, counter, message)

            if self.should_display and visualize:
                if self.mnist:
                    reduced_memory = Utilities.reduce_memory(memory)
                    plot_mnist_color(reduced_memory, i)
                else:
                    self.create_tsm_solution(i)
                    self.tsm_visualizer.update_weights(self.weights)

        # NOTE(review): with n_epochs == 0, l_rate and radius are still None
        # here and the "%" format below raises TypeError — confirm whether a
        # zero-epoch run should be rejected up front.
        message = "Epoch: %d \t L_Rate: %.3f \t Radius: %3d" % (
            self.n_epochs - 1, l_rate, radius)
        Utilities.print_progress(1, 1, message)
        print("\n\nDone Training")

        if self.mnist:
            self.test(memory, self.weights, True)
            self.test(memory, self.weights, False)
        else:
            if self.should_display:
                self.tsm_visualizer.update_weights(self.weights)
            total = self.create_tsm_solution(self.n_epochs - 1)
            print("\nLength of best route: %.2f" % total)
            return total

    def _train_on_case(self, case, l_rate, radius):
        """Apply one SOM update for ``case`` and return its winner's (row, col).

        The winning neuron is moved with the full learning rate.  When
        ``radius`` is non-zero, each neighbour is also moved, with the learning
        rate scaled by ``exp(-dist / (2 * radius**2))``.  In MNIST mode the
        neighbours come from ``generate_neighbour_coordinates`` and ``dist`` is
        a Euclidean distance; otherwise they come from
        ``generate_tsm_neighbours`` and ``dist`` is a ring distance over
        ``self.n_output_cols`` columns.
        """
        row, col = Utilities.get_winning_neuron_2d(case, self.weights)
        Utilities.update_weight_matrix_2d(case, l_rate, row, col, self.weights)
        if not radius:
            # A zero radius means only the winning neuron is updated.
            return row, col

        if self.mnist:
            neighbours = self.generate_neighbour_coordinates(row, col, radius)
        else:
            neighbours = self.generate_tsm_neighbours(col, radius)

        for neighbour in neighbours:
            if self.mnist:
                dist = Utilities.euclidian_distance(
                    tensor(neighbour), tensor((row, col)))
            else:
                dist = Utilities.ring_distance(
                    neighbour, (row, col), self.n_output_cols)
            # NOTE(review): a Gaussian neighbourhood is usually
            # exp(-dist**2 / (2 * radius**2)); dist is not squared here —
            # confirm whether euclidian_distance/ring_distance already return
            # squared distances or whether this is intentional.
            influence = math.exp(-(dist / (2 * radius**2)))
            Utilities.update_weight_matrix_2d(
                case, influence * l_rate, neighbour[0], neighbour[1],
                self.weights)
        return row, col