Esempio n. 1
0
def save_model(filename: str, model: EncoderDecoder) -> None:
    """Write ``model`` and its hyper-parameters to a compressed NumPy archive.

    ``model.hyper_params`` is pickled and stored as a ``uint8`` array under
    the ``"hyper_parameters"`` entry; the model itself is saved through the
    serializer's ``"model"`` child.

    Args:
        filename: Destination handed to ``np.savez_compressed``.
        model: Model to store; its ``hyper_params`` attribute is pickled.
    """
    # Raw pickle bytes are viewed as uint8 so they fit in the archive as an
    # ordinary array.
    hyper_param_bytes = pickle.dumps(model.hyper_params)
    hyper_param_array = np.frombuffer(hyper_param_bytes, dtype=np.uint8)

    archive = serializers.DictionarySerializer()
    archive("hyper_parameters", hyper_param_array)
    archive["model"].save(model)

    np.savez_compressed(filename, **archive.target)
    def save_checkpoint(self,
                        save_path,
                        epoch,
                        step,
                        lr,
                        metric_dev_best,
                        remove_old_checkpoints=False):
        """Save the model, its optimizer and the training progress.

        Args:
            save_path (string): path to save a model (directory)
            epoch (int): the current epoch
            step (int): the current step
            lr (float): value stored under the ``"lr"`` key of the checkpoint
            metric_dev_best (float): value stored under the
                ``"metric_dev_best"`` key of the checkpoint
            remove_old_checkpoints (bool, optional): if True, every existing
                ``model.epoch-*`` file in ``save_path`` is deleted before the
                new checkpoint is written
        Returns:
            model_path (string): path to the saved model (file)
        """
        # np.savez_compressed appends ".npz" when the name lacks it; apply the
        # same rule up front so the logged and returned path is the file that
        # actually ends up on disk.
        model_path = join(save_path, 'model.epoch-' + str(epoch))
        if not model_path.endswith('.npz'):
            model_path += '.npz'

        # Remove old checkpoints (done before the new one is written)
        if remove_old_checkpoints:
            for path in glob(join(save_path, 'model.epoch-*')):
                os.remove(path)

        checkpoint = {
            "epoch": epoch,
            "step": step,
            "lr": lr,
            "metric_dev_best": metric_dev_best
        }

        # Save parameters, optimizer, step index etc. in a single archive.
        serializer = serializers.DictionarySerializer()

        # The progress dict is pickled and stored as a uint8 array so that it
        # can live next to the parameter arrays.
        pickled_params = np.frombuffer(pickle.dumps(checkpoint),
                                       dtype=np.uint8)
        serializer("checkpoint", pickled_params)

        serializer["model"].save(self)
        serializer["optimizer"].save(self.optimizer)

        np.savez_compressed(model_path, **serializer.target)

        logger.info("=> Saved checkpoint (epoch:%d): %s" % (epoch, model_path))
        return model_path
Esempio n. 3
0
    def test_iterator_serialize_backward_compat(self):
        """Check a SerialIterator can resume from state in the older layout.

        Two batches are drawn, the iterator state is dumped into a dict and
        rewritten the way the older version stored it (``'_order'`` instead
        of ``'order'``, no ``'previous_epoch_detail'``), and a new iterator
        loaded from that dict must continue from the same position.
        """
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.SerialIterator(
            dataset, 2,
            shuffle=self.shuffle, order_sampler=self.order_sampler)

        # Nothing has been consumed yet.
        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)

        # Draw the first two batches, checking the progress after each one.
        drawn = []
        for seen in (2, 4):
            batch = it.next()
            drawn.append(batch)
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, seen / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, (seen - 2) / 6)
        batch1, batch2 = drawn

        state = {}
        it.serialize(serializers.DictionarySerializer(state))
        # older version uses '_order'
        state['_order'] = state.pop('order')
        # older version does not have previous_epoch_detail
        del state['previous_epoch_detail']

        # Load the rewritten state into a fresh iterator.
        it = iterators.SerialIterator(dataset, 2)
        it.serialize(serializers.NpzDeserializer(state))
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        # The last batch completes the epoch and covers the remaining items.
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
Esempio n. 4
0
    def test_iterator_compatibilty(self):
        """Check state saved by one iterator type loads into the other.

        For each ordered pair of (SerialIterator, MultiprocessIterator), two
        batches are drawn from the first, its state is serialized into a
        dict, and the second must resume from that state and finish the epoch.
        """
        dataset = [1, 2, 3, 4, 5, 6]

        factories = (
            lambda: iterators.SerialIterator(dataset, 2),
            lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
        )

        for make_source, make_target in itertools.permutations(factories, 2):
            it = make_source()

            # Nothing has been consumed yet.
            self.assertEqual(it.epoch, 0)
            self.assertAlmostEqual(it.epoch_detail, 0 / 6)

            # Draw the first two batches, checking the progress after each.
            drawn = []
            for seen in (2, 4):
                batch = it.next()
                drawn.append(batch)
                self.assertEqual(len(batch), 2)
                self.assertIsInstance(batch, list)
                self.assertFalse(it.is_new_epoch)
                self.assertAlmostEqual(it.epoch_detail, seen / 6)
            batch1, batch2 = drawn

            state = {}
            it.serialize(serializers.DictionarySerializer(state))

            # Resume with the other iterator type.
            it = make_target()
            it.serialize(serializers.NpzDeserializer(state))
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            # The last batch completes the epoch.
            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
            self.assertAlmostEqual(it.epoch_detail, 6 / 6)
Esempio n. 5
0
    def test_iterator_serialize(self):
        """Check a MultiprocessIterator resumes from its serialized state.

        Two batches are drawn, the state is serialized into a dict, and a new
        iterator loaded from that dict must report the same progress and
        yield the remaining batch of the epoch.
        """
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        # Nothing has been consumed yet.
        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)

        # Draw the first two batches, checking the progress after each one.
        drawn = []
        for seen in (2, 4):
            batch = it.next()
            drawn.append(batch)
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, seen / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, (seen - 2) / 6)
        batch1, batch2 = drawn

        state = {}
        it.serialize(serializers.DictionarySerializer(state))

        # Load the saved state into a fresh iterator.
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        it.serialize(serializers.NpzDeserializer(state))
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        # The last batch completes the epoch and covers the remaining items.
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
def save_to_csv(directory, obj):
    """Serialize ``obj`` into a dictionary and hand it to ``npz_to_csv``.

    Args:
        directory: Not referenced anywhere in this function body.
        obj: Object passed to ``DictionarySerializer.save``.
    """
    serializer = serializers.DictionarySerializer()
    serializer.save(obj)
    npz_to_csv(serializer.target)