Example #1
0
    def load_weights(self):
        """
        Load the .weights file at `self.src_bin` into every layer.

        A weights loader is built with `loader.create_loader` from
        `self.src_bin` and `self.src_layers`, then handed to the `load`
        method of each entry in `self.layers`, in order.

        Two progress lines are printed to stdout: one before loading starts
        and one afterwards. The second line reports the elapsed wall-clock
        time in seconds, measured with `time.time()`.
        """
        print('Loading {} ...'.format(self.src_bin))
        began = time.time()

        # Factory arguments are positional: weights path first, then the
        # layer source.
        weights_source = loader.create_loader(self.src_bin, self.src_layers)
        for each_layer in self.layers:
            each_layer.load(weights_source)

        elapsed = time.time() - began
        print('Finished in {}s'.format(elapsed))
Example #2
0
def load_old_graph(self, ckpt):
    """
    Restore every global graph variable from the checkpoint `ckpt`.

    A loader is built with `create_loader(ckpt)`. For each variable, the
    loader is called with `[name, shape]`:
    - `name` is the variable's tensor name with its `:<index>` suffix
      stripped.
    - `shape` is the variable's static shape.

    The returned value is written into the variable by feeding it through a
    placeholder into an assign op run in `self.sess`.

    Raises:
        AssertionError: if the loader returns None for some variable.
    """
    ckpt_loader = create_loader(ckpt)
    # Single-argument print(...) behaves the same under Python 2 and 3,
    # unlike the bare print statement, which is a SyntaxError on Python 3.
    print(old_graph_msg.format(ckpt))

    # tf.all_variables is the deprecated alias of tf.global_variables.
    # Prefer the current name and fall back for older TensorFlow releases.
    list_variables = getattr(tf, 'global_variables', None) or tf.all_variables

    for var in list_variables():
        name = var.name.split(':')[0]
        args = [name, var.get_shape()]
        val = ckpt_loader(args)
        # Use an explicit check rather than `assert` so it is not stripped
        # under `python -O`. AssertionError is kept so existing handlers
        # still match.
        if val is None:
            raise AssertionError('Cannot find and load {}'.format(var.name))
        # NOTE(review): every value is fed as float32. Confirm no
        # non-float variables are expected here.
        plh = tf.placeholder(tf.float32, val.shape)
        op = tf.assign(var, plh)
        self.sess.run(op, {plh: val})