Code example #1
0
    def test_dynamic_plotting(self):
        """Feed many steps of three series into a Plotter and check the PNG gets written."""
        plotter = Plotter()
        # Named num_steps rather than `max` so the builtin max() is not shadowed.
        num_steps = 3000
        for step in range(num_steps):
            plotter.add_values([
                ("loss", (num_steps - step) / num_steps),
                ("evaluation score", step / num_steps / 2),
                ("second score", 0.3),
            ])

        # Give the extension explicitly so the file checked below does not depend
        # on the configured default savefig format.
        plotter.plot("DynamicTestPlot").savefig("DynamicTestPlot.png")
        self.assertTrue(os.path.exists("DynamicTestPlot.png"))
class Experiment(ABC):
    """
    Base class for running experiments. Provides a plotter as well as path handling.

    DO NOT FORGET TO CALL super().__init__() in subclasses: the plotter and the
    artifact path are only set up there.

    The class reads a ``config`` attribute that is not defined here; it must
    provide ``TIC_TAC_TOE_DIR`` and ``find_in_subdirectory``.
    """

    def __init__(self):
        self.experiment_name = self.__class__.__name__
        self.__plotter__ = Plotter()
        self.last_plot = None

        # Plots and saved players are written into a per-experiment directory.
        self.path = self.config.TIC_TAC_TOE_DIR + "/experiments/artifacts/%s/" % self.experiment_name

    @abstractmethod
    def run(self, silent=False):
        """Run the experiment. The meaning of ``silent`` is defined by subclasses."""
        pass

    @abstractmethod
    def reset(self):
        """Reset the experiment's state; implemented by subclasses."""
        pass

    @classmethod
    def load_player(cls, player_name):
        """
        Locate a saved player below ``<TIC_TAC_TOE_DIR>/experiments`` and load it.

        :param player_name: name handed to ``config.find_in_subdirectory`` to find the file
        :return: the object deserialized by ``torch.load``
        """
        filename = cls.config.find_in_subdirectory(
            player_name, cls.config.TIC_TAC_TOE_DIR + "/experiments")
        # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
        # Only load player files from trusted sources.
        return torch.load(filename)

    def _require_plotter(self):
        """
        Return the internal plotter, failing loudly if ``__init__`` was skipped.

        ``getattr`` is used so that a missing attribute produces the explanatory
        message instead of a bare AttributeError. ``is None`` is used (not
        truthiness) so that a plotter object that happens to be falsy is still
        accepted.
        """
        plotter = getattr(self, "__plotter__", None)
        if plotter is None:
            raise Exception(
                "__plotter__ not initialized, Experiment's super() must be called"
            )
        return plotter

    def add_results(self, results):
        """
        Takes a single tuple or a list of tuples (name, value) and appends them to the internal plotter.
        Each distinct name is plotted as a separate series, with its values interpolated to fit the other values.

        :param results: a single (name, value) tuple or a list of such tuples
        :return: None
        :raises Exception: if super().__init__() was not called, if ``results`` is
            neither a tuple nor a list, or if the plotter rejects the values
        """
        plotter = self._require_plotter()
        if isinstance(results, tuple):
            results = [results]
        elif not isinstance(results, list):
            # Previously such arguments were silently dropped; reject them instead.
            raise Exception("add_result received an illegal argument: expected a tuple "
                            "or a list of tuples, got " + type(results).__name__)
        try:
            plotter.add_values(results)
        except Exception as e:
            raise Exception("add_result received an illegal argument: " +
                            str(e)) from e

    def add_loss(self, loss):
        """Forward a single loss value to the internal plotter."""
        self._require_plotter().add_loss(loss)

    def plot_scores(self, title):
        """Plot everything collected so far and keep the result in ``self.last_plot``."""
        self.last_plot = self._require_plotter().plot(title)

    def plot_and_save(self, file_name, plot_title=""):
        """
        Plot the collected results and save them as ``<self.path><file_name>.png``.

        :param file_name: file name without extension; also the title if none is given
        :param plot_title: optional title for the plot
        """
        self.plot_scores(plot_title if plot_title else file_name)

        # exist_ok avoids the exists()/makedirs() race of a separate check.
        os.makedirs(self.path, exist_ok=True)

        self.last_plot.savefig(self.path + file_name + ".png")
        self.last_plot.close("all")

    def save_player(self, player, description=""):
        """
        Save ``player`` with ``torch.save`` as ``<self.path><str(player)> <description>.pth``.

        The separating space is written even when ``description`` is empty, to
        keep file names compatible with previously saved players.
        """
        os.makedirs(self.path, exist_ok=True)

        torch.save(player,
                   self.path + str(player) + " " + description + ".pth")

    @property
    def num_episodes(self):
        """Number of episodes as reported by the internal plotter."""
        return self.__plotter__.num_episodes

    def __str__(self):
        return self.__class__.__name__

    class AlternatingColorIterator:
        """
        Infinite iterator that yields WHITE, BLACK, WHITE, BLACK, ...
        (colour constants imported from Othello.config).
        """
        def __init__(self):
            from Othello.config import BLACK, WHITE
            self.colors = [BLACK, WHITE]

        def __iter__(self):
            return self

        def __next__(self):
            # Flip the order on each call and yield the new head, so the first
            # value returned is WHITE.
            self.colors = list(reversed(self.colors))
            return self.colors[0]