Ejemplo n.º 1
0
    def test_container(self):
        """Run a one-epoch SupervisedTrainer inside a ThreadContainer and check its status.

        Training runs on the container's background thread. The test waits (bounded)
        until the container reports a non-empty status, checks ``status()`` and
        ``status_dict``, then joins the thread and verifies it actually finished.
        """
        net = torch.nn.Conv2d(1, 1, 3, padding=1)

        opt = torch.optim.Adam(net.parameters())

        img = torch.rand(1, 16, 16)
        data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img}
        loader = DataLoader([data for _ in range(10)])

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=loader,
            network=net,
            optimizer=opt,
            loss_function=torch.nn.L1Loss(),
        )

        con = ThreadContainer(trainer)
        con.start()
        try:
            # Poll instead of sleeping a fixed interval: this tiny training run can
            # finish long before a fixed sleep elapses, while a loaded machine may
            # not have produced any status yet. Stop waiting as soon as status is
            # available, the thread has exited, or the deadline passes.
            deadline = time.monotonic() + 30.0
            while not con.status_dict and con.is_alive() and time.monotonic() < deadline:
                time.sleep(0.01)

            self.assertIsNotNone(con.status())
            self.assertGreater(len(con.status_dict), 0)
        finally:
            # Always reap the worker so a failed assertion does not leave training
            # running into later tests; bound the wait so a hung trainer fails the
            # test instead of blocking the suite forever.
            con.join(timeout=60.0)

        # ``is_alive`` must be *called*: as a bare attribute it is a bound method and
        # always truthy. Assumes ThreadContainer is a ``threading.Thread`` subclass
        # (it is driven via start/join) — confirm if that ever changes.
        self.assertFalse(con.is_alive(), "ThreadContainer did not finish training within the timeout")
Ejemplo n.º 2
0
    def test_plot(self):
        """Train briefly in a ThreadContainer and compare its status plot to a reference PNG.

        A non-image entry is included in each batch to check that plotting skips
        data that is not an image.
        """
        set_determinism(0)

        here = os.path.dirname(os.path.realpath(__file__))
        ref_dir = os.path.join(here, "testing_data")

        # Network creation and the random image must keep this order after seeding so
        # the produced plot matches the stored reference image.
        model = torch.nn.Conv2d(1, 1, 3, padding=1)
        optimizer = torch.optim.Adam(model.parameters())
        image = torch.rand(1, 16, 16)

        sample = {CommonKeys.IMAGE: image, CommonKeys.LABEL: image}
        # extra non-image entry: plotting must ignore it
        sample["Not Image Data"] = ["This isn't an image"]

        # ten references to the same sample dict, as in a repeated list
        train_loader = DataLoader([sample for _ in range(10)])

        criterion = torch.nn.L1Loss()
        trainer = SupervisedTrainer(
            network=model,
            optimizer=optimizer,
            loss_function=criterion,
            train_data_loader=train_loader,
            max_epochs=1,
            device=torch.device("cpu"),
        )

        metric_logger = MetricLogger()
        metric_logger.attach(trainer)

        container = ThreadContainer(trainer)
        container.start()
        container.join()

        figure = container.plot_status(metric_logger)

        with tempfile.TemporaryDirectory() as scratch:
            produced = f"{scratch}/threadcontainer_plot_test.png"
            figure.savefig(produced)
            expected = f"{ref_dir}/threadcontainer_plot_test.png"
            mismatch = compare_images(expected, produced, 1e-3)

            # compare_images returns None on a match, otherwise a description of the difference
            self.assertIsNone(mismatch, mismatch)