Example #1
File: agent.py Project: pazocal/chainerrl
def load_npz_no_strict(filename, obj):
    try:
        serializers.load_npz(filename, obj)
    except KeyError as e:
        warnings.warn(repr(e))
        with numpy.load(filename) as f:
            d = serializers.NpzDeserializer(f, strict=False)
            d.load(obj)
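
A minimal usage sketch for the helper above; TinyChain and the file name are illustrative, not from the source:

import chainer
import chainer.links as L
from chainer import serializers

class TinyChain(chainer.Chain):
    def __init__(self):
        super().__init__()
        with self.init_scope():
            self.fc = L.Linear(3, 2)

model = TinyChain()
serializers.save_npz('tiny.npz', model)  # write a checkpoint
load_npz_no_strict('tiny.npz', model)    # strict load succeeds; falls back only on KeyError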
Example #2
File: train.py Project: nojima/workspace
def load_model(filename: str) -> Tuple[Word2Vec, HyperParameters]:
    with np.load(filename) as f:
        deserializer = serializers.NpzDeserializer(f)

        pickled_params = deserializer("hyper_parameters", None)
        params = pickle.loads(
            pickled_params.tobytes())  # type: HyperParameters

        model = _make_model(params)
        deserializer["model"].load(model)

        return model, params
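
The matching save side is not shown in the source; a plausible sketch, assuming HyperParameters pickles cleanly (it mirrors how chainer's own save_npz writes a DictionarySerializer target):

def save_model(filename: str, model: Word2Vec, params: HyperParameters) -> None:
    serializer = serializers.DictionarySerializer()
    # Store the pickled hyper-parameters as a uint8 array so that
    # tobytes() in load_model recovers the original byte string.
    serializer("hyper_parameters",
               np.frombuffer(pickle.dumps(params), dtype=np.uint8))
    serializer["model"].save(model)
    np.savez_compressed(filename, **serializer.target)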
Example #3
    def test_iterator_serialize_backward_compat(self):
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.SerialIterator(dataset,
                                      2,
                                      shuffle=self.shuffle,
                                      order_sampler=self.order_sampler)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        target = dict()
        it.serialize(serializers.DictionarySerializer(target))
        # older version uses '_order'
        target['_order'] = target['order']
        del target['order']
        # older version does not have previous_epoch_detail
        del target['previous_epoch_detail']

        it = iterators.SerialIterator(dataset, 2)
        it.serialize(serializers.NpzDeserializer(target))
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
Example #4
def load_model(filename: str) -> EncoderDecoder:
    with np.load(filename) as f:
        d = dict(f.items())

        # Rename _W to _extract_output for backward compatibility
        for k, v in list(d.items()):
            old_prefix = "model/_W/"
            new_prefix = "model/_extract_output/"
            if k.startswith(old_prefix):
                d[new_prefix + k[len(old_prefix):]] = v

        deserializer = serializers.NpzDeserializer(d)

        pickled_params = deserializer("hyper_parameters", None)
        params = pickle.loads(pickled_params.tobytes())

        model = EncoderDecoder(*params)
        deserializer["model"].load(model)

        return model
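
NpzDeserializer reads only the keys it is asked for, so the leftover model/_W/ entries above are harmless; pruning them explicitly would be a cosmetic variant (not in the source):

for k in list(d):
    if k.startswith("model/_W/"):
        del d[k]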
Example #5
    def test_iterator_compatibility(self):
        dataset = [1, 2, 3, 4, 5, 6]

        iters = (
            lambda: iterators.SerialIterator(dataset, 2),
            lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
        )

        for it_before, it_after in itertools.permutations(iters, 2):
            it = it_before()

            self.assertEqual(it.epoch, 0)
            self.assertAlmostEqual(it.epoch_detail, 0 / 6)
            batch1 = it.next()
            self.assertEqual(len(batch1), 2)
            self.assertIsInstance(batch1, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 2 / 6)
            batch2 = it.next()
            self.assertEqual(len(batch2), 2)
            self.assertIsInstance(batch2, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            target = dict()
            it.serialize(serializers.DictionarySerializer(target))

            it = it_after()
            it.serialize(serializers.NpzDeserializer(target))
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
            self.assertAlmostEqual(it.epoch_detail, 6 / 6)
Example #6
    def test_iterator_serialize(self):
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        target = dict()
        it.serialize(serializers.DictionarySerializer(target))

        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        it.serialize(serializers.NpzDeserializer(target))
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
Example #7
File: train.py Project: strategist922/knmt
def load_model_flexible(filename_list, encdec):
    mode = "normal"
    if isinstance(filename_list, (tuple, list)):
        if len(filename_list) == 1:
            filename_list = filename_list[0]
        else:
            mode = "average"

    if mode == "normal":
        log.info("loading model parameters from %s", filename_list)
        try:
            serializers.load_npz(filename_list, encdec)
        except KeyError:
            log.info("not model format, trying snapshot format")
            with np.load(filename_list) as fseri:
                dicseri = serializers.NpzDeserializer(
                    fseri, path="updater/model:main/")
                dicseri.load(encdec)
    else:
        assert mode == "average"
        log.info("loading averaged model parameters from %r", filename_list)
        dseri = NpzDeserializerAverage(
            [np.load(filename) for filename in filename_list])
        dseri.load(encdec)
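
NpzDeserializerAverage is a knmt-specific class whose implementation is not shown here; a minimal sketch of the parameter-averaging idea its call site implies, assuming every file holds the same keys with the same shapes:

def average_npz_params(filename_list):
    loaded = [np.load(filename) for filename in filename_list]
    keys = loaded[0].files
    # Average each array key-by-key across all checkpoints.
    return {k: sum(f[k] for f in loaded) / len(loaded) for k in keys}

# The averaged dict can then be loaded through a plain NpzDeserializer:
# serializers.NpzDeserializer(average_npz_params(filename_list)).load(encdec)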
Example #8
def load_npz(filename, obj):
    with np.load(filename) as f:
        d = serializers.NpzDeserializer(f, strict=False)
        d.load(obj)
Example #9
def load_from_csv(directory, model):
    dic_params = csv_to_npz(directory)
    d = serializers.NpzDeserializer(dic_params)
    d.load(model)
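
csv_to_npz is project-specific and not shown; a hypothetical reading of what it might do, assuming one CSV file per parameter with the key path encoded in the file name:

from glob import glob
from os.path import basename, join

import numpy as np

def csv_to_npz(directory):
    params = {}
    for path in glob(join(directory, '*.csv')):
        # Hypothetical convention: 'model__fc__W.csv' -> key 'model/fc/W'
        key = basename(path)[:-len('.csv')].replace('__', '/')
        params[key] = np.loadtxt(path, delimiter=',')
    return params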
Example #10
    def load_checkpoint(self,
                        save_path,
                        epoch=-1,
                        restart=False,
                        load_pretrained_model=False):
        """Load checkpoint.
        Args:
            save_path (string): path to the saved models
            epoch (int, optional): if -1 means the last saved model
            restart (bool, optional): if True, restore the save optimizer
            load_pretrained_model (bool, optional): if True, load all parameters
                which match those of the new model's parameters
        Returns:
            epoch (int): the currnet epoch
            step (int): the current step
            lr (float):
            metric_dev_best (float):
        """
        if int(epoch) == -1:
            # Restore the last saved model
            epochs = [(int(basename(x).split('-')[-1].split('.')[0]), x)
                      for x in glob(join(save_path, 'model.*'))]

            if len(epochs) == 0:
                raise ValueError('no saved model found in %s' % save_path)

            epoch = sorted(epochs, key=lambda x: x[0])[-1][0]

        model_path = join(save_path, 'model.epoch-' + str(epoch) + '.npz')

        if isfile(model_path):
            with np.load(model_path) as f:
                deserializer = serializers.NpzDeserializer(f)

                pickled_params = deserializer("checkpoint", None)
                checkpoint = pickle.loads(
                    pickled_params.tobytes())  # dict with the training state

                # Restore parameters
                if load_pretrained_model:
                    logger.info(
                        "=> Loading pre-trained checkpoint (epoch:%d): %s" %
                        (epoch, model_path))

                    # TODO:
                    # pretrained_dict = checkpoint['state_dict']
                    # model_dict = self.state_dict()
                    #
                    # # 1. filter out unnecessary keys and params which do not match size
                    # pretrained_dict = {
                    #     k: v for k, v in pretrained_dict.items() if k in model_dict.keys() and v.size() == model_dict[k].size()}
                    # # 2. overwrite entries in the existing state dict
                    # model_dict.update(pretrained_dict)
                    # # 3. load the new state dict
                    # self.load_state_dict(model_dict)

                    # for k in pretrained_dict.keys():
                    #     logger.info(k)

                deserializer["model"].load(self)

                # Restore optimizer
                if restart:
                    if hasattr(self, 'optimizer'):
                        deserializer["optimizer"].load(self.optimizer)
                    else:
                        raise ValueError('Set optimizer.')
                else:
                    print("=> Loading checkpoint (epoch:%d): %s" %
                          (epoch, model_path))

        else:
            raise ValueError("No checkpoint found at %s" % model_path)

        return (checkpoint['epoch'] + 1, checkpoint['step'] + 1,
                checkpoint['lr'], checkpoint['metric_dev_best'])
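
A hypothetical call site for the method above, assuming it is defined on a chainer.Chain subclass with an optimizer attribute:

epoch, step, lr, metric_dev_best = model.load_checkpoint(
    'checkpoints', epoch=-1, restart=True)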
Example #11
    # Earlier vae_type branches of this script are omitted in this excerpt.
    else:
        vae = ais.AIS(decoder,
                      M=zcount,
                      T=ais_temps,
                      steps=ais_steps,
                      stepsize=ais_stepsize,
                      sigma=ais_sigma,
                      encoder=encoder)
else:
    sys.exit("Unsupported VAE type")

#serializers.load_hdf5(model_file, vae)
#serializers.load_npz(model_file, vae, path='updater/model:main/')
try:
    with np.load(model_file) as f:
        d = serializers.NpzDeserializer(f, path='updater/model:main/')
        d.load(vae)
except KeyError:  # snapshot stores the model under a different path
    with np.load(model_file) as f:
        d = serializers.NpzDeserializer(f, path='updater/model:elbo/')
        d.load(vae)
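
The try/except above hard-codes two snapshot layouts; the same fallback generalizes to any list of candidate paths (a sketch, not in the source):

def load_with_fallback(model_file, obj, paths=('updater/model:main/',
                                               'updater/model:elbo/')):
    last_err = None
    for path in paths:
        try:
            with np.load(model_file) as f:
                serializers.NpzDeserializer(f, path=path).load(obj)
            return path  # report which layout matched
        except KeyError as e:
            last_err = e
    raise last_err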

print "Deserialized model '%s' of type '%s'" % (model_file, vae_type_train)

if gpu_id >= 0:
    vae.to_gpu(gpu_id)
    print("Moved model to GPU %d" % gpu_id)

# For debugging purposes, optionally, obtain and write logw value for the
# first few test samples
if '--logw' in args and args['--logw'] is not None: