def save_model(filename: str, model: EncoderDecoder) -> None:
    """Write *model* and its pickled hyper-parameters into a compressed ``.npz``.

    The hyper-parameters are pickled and stored as a raw uint8 array so they
    can live alongside the model weights inside the same npz archive.
    """
    dict_serializer = serializers.DictionarySerializer()
    raw_bytes = pickle.dumps(model.hyper_params)
    dict_serializer("hyper_parameters", np.frombuffer(raw_bytes, dtype=np.uint8))
    dict_serializer["model"].save(model)
    np.savez_compressed(filename, **dict_serializer.target)
def save_checkpoint(self, save_path, epoch, step, lr, metric_dev_best,
                    remove_old_checkpoints=False):
    """Save a training checkpoint (model, optimizer and bookkeeping state).

    Args:
        save_path (string): path to save a model (directory)
        epoch (int): the current epoch
        step (int): the current step
        lr (float): current learning rate
        metric_dev_best (float): best dev-set metric seen so far
        remove_old_checkpoints (bool, optional): if True, all checkpoints
            other than the best one will be deleted
    Returns:
        model_path (string): path to the saved model (file)
    """
    model_path = join(save_path, 'model.epoch-' + str(epoch))

    # Remove old checkpoints before writing the new one.
    if remove_old_checkpoints:
        for path in glob(join(save_path, 'model.epoch-*')):
            os.remove(path)

    checkpoint = {
        "epoch": epoch,
        "step": step,
        "lr": lr,
        "metric_dev_best": metric_dev_best
    }

    # Bundle pickled bookkeeping state, model parameters and optimizer
    # state into a single compressed npz archive.
    serializer = serializers.DictionarySerializer()
    pickled_params = np.frombuffer(pickle.dumps(checkpoint), dtype=np.uint8)
    serializer("checkpoint", pickled_params)
    serializer["model"].save(self)
    serializer["optimizer"].save(self.optimizer)
    np.savez_compressed(model_path, **serializer.target)

    # Lazy %-args: formatting is deferred until the record is actually emitted.
    logger.info("=> Saved checkpoint (epoch:%d): %s", epoch, model_path)
    return model_path
def test_iterator_serialize_backward_compat(self):
    """Deserializing state saved in the pre-'order' / pre-previous_epoch_detail
    format must restore the iterator mid-epoch."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.SerialIterator(dataset, 2, shuffle=self.shuffle,
                                  order_sampler=self.order_sampler)

    # Fresh iterator: nothing consumed yet.
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)

    # Consume two batches, checking progress after each.
    batches = []
    for consumed in (2, 4):
        batch = it.next()
        batches.append(batch)
        self.assertEqual(len(batch), 2)
        self.assertIsInstance(batch, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, consumed / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, (consumed - 2) / 6)

    # Snapshot the state, then rewrite it into the legacy on-disk layout.
    state = dict()
    it.serialize(serializers.DictionarySerializer(state))
    # older version uses '_order'
    state['_order'] = state.pop('order')
    # older version does not have previous_epoch_detail
    del state['previous_epoch_detail']

    # Restore into a fresh iterator from the legacy-format state.
    it = iterators.SerialIterator(dataset, 2)
    it.serialize(serializers.NpzDeserializer(state))
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    # The third batch completes the epoch and covers the whole dataset.
    final_batch = it.next()
    self.assertEqual(len(final_batch), 2)
    self.assertIsInstance(final_batch, list)
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(sorted(batches[0] + batches[1] + final_batch), dataset)
    self.assertAlmostEqual(it.epoch_detail, 6 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
def test_iterator_compatibilty(self):
    """State serialized by one iterator class must be loadable by the other,
    in both directions."""
    dataset = [1, 2, 3, 4, 5, 6]
    makers = (
        lambda: iterators.SerialIterator(dataset, 2),
        lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
    )
    # Try both (source, destination) orderings of the two iterator classes.
    for make_src, make_dst in itertools.permutations(makers, 2):
        it = make_src()
        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)

        # Consume two batches from the source iterator.
        collected = []
        for consumed in (2, 4):
            batch = it.next()
            collected.append(batch)
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, consumed / 6)

        # Round-trip the state into the other iterator class.
        state = dict()
        it.serialize(serializers.DictionarySerializer(state))
        it = make_dst()
        it.serialize(serializers.NpzDeserializer(state))
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)

        # The restored iterator finishes the epoch over the full dataset.
        final_batch = it.next()
        self.assertEqual(len(final_batch), 2)
        self.assertIsInstance(final_batch, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(collected[0] + collected[1] + final_batch),
                         dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
def test_iterator_serialize(self):
    """A MultiprocessIterator restored from serialized state resumes mid-epoch
    exactly where the original left off."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)

    # Fresh iterator: nothing consumed yet.
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)

    # Consume two batches, checking progress after each.
    batches = []
    for consumed in (2, 4):
        batch = it.next()
        batches.append(batch)
        self.assertEqual(len(batch), 2)
        self.assertIsInstance(batch, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, consumed / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, (consumed - 2) / 6)

    # Snapshot the state and load it into a brand-new iterator.
    state = dict()
    it.serialize(serializers.DictionarySerializer(state))
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    it.serialize(serializers.NpzDeserializer(state))
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    # The third batch completes the epoch and covers the whole dataset.
    final_batch = it.next()
    self.assertEqual(len(final_batch), 2)
    self.assertIsInstance(final_batch, list)
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(sorted(batches[0] + batches[1] + final_batch), dataset)
    self.assertAlmostEqual(it.epoch_detail, 6 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
def save_to_csv(directory, obj):
    """Serialize *obj* into a dictionary target and hand it to ``npz_to_csv``.

    NOTE(review): the ``directory`` parameter is not used anywhere in this
    body — kept for interface compatibility; confirm against callers.
    """
    dict_serializer = serializers.DictionarySerializer()
    dict_serializer.save(obj)
    npz_to_csv(dict_serializer.target)