コード例 #1
0
ファイル: mobilenetv1.py プロジェクト: vermicelli/Maix-EMC
def restore_params(network, alpha, path='models'):
    """Load MobileNetV1 pre-trained parameters for a given width multiplier.

    The archive ``mbnetv1_<alpha>.npz`` is read from ``path``; it is NOT
    downloaded here (the download call is disabled), so it must already exist.
    Batch-norm arrays are reshaped to ``(1, 1, 1, -1)``, and every batch-norm
    ``beta`` is swapped with the array that follows it, because TensorLayer and
    Keras order beta/gamma differently. The first ``len(network.all_weights)``
    arrays are then assigned to ``network``.

    Parameters
    ----------
    network : TensorLayer model whose ``all_weights`` are overwritten.
    alpha : width multiplier; its ``str()`` form selects the archive name.
    path : directory containing the archive.
    """
    logging.info("Restore pre-trained parameters")
    # Download step intentionally disabled; the archive must be present locally.
    # maybe_download_and_extract(
    #     'mobilenet.npz', path, 'https://github.com/tensorlayer/pretrained-models/raw/master/models/',
    #     expected_bytes=25600116
    # )
    archive = "mbnetv1_%s.npz" % (alpha,)
    params = load_npz(name=os.path.join(path, archive))

    weights = network.all_weights
    n_weights = len(weights)

    # Batch-norm parameters are stored flat; the layers expect rank-4 arrays.
    bn_positions = [k for k, w in enumerate(weights) if 'batchnorm' in w.name]
    for k in bn_positions:
        params[k] = params[k].reshape(1, 1, 1, -1)

    # Swap each batch-norm beta with its successor, then skip past the pair.
    k = 0
    while k < n_weights:
        wname = weights[k].name
        if 'batchnorm' in wname and 'beta' in wname:
            params[k], params[k + 1] = params[k + 1], params[k]
            k += 2
            continue
        k += 1

    assign_weights(params[:n_weights], network)
    del params
コード例 #2
0
def restore_model(model, layer_type):
    """Download (if needed) and load pre-trained VGG weights into ``model``.

    Parameters
    ----------
    model : TensorLayer model; at most ``len(model.weights)`` arrays are assigned.
    layer_type : str
        ``'vgg16'`` or ``'vgg19'``; also the key into the module-level
        ``model_saved_name`` / ``model_urls`` lookups.

    Raises
    ------
    TypeError
        If ``layer_type`` is not ``'vgg16'`` or ``'vgg19'`` (checked before
        anything is downloaded).
    """
    logging.info("Restore pre-trained weights")
    if layer_type not in ('vgg16', 'vgg19'):
        raise TypeError('layer type not supported for restore_model(): %s' % layer_type)

    # download weights
    maybe_download_and_extract(model_saved_name[layer_type], 'models', model_urls[layer_type])
    weights_path = os.path.join('models', model_saved_name[layer_type])
    n_target = len(model.weights)
    weights = []
    if layer_type == 'vgg16':
        # ``with`` closes the zip archive; arrays are read lazily, one key at a time.
        with np.load(weights_path) as npz:
            for name in sorted(npz.files):
                arr = npz[name]
                logging.info("  Loading weights %s in %s" % (str(arr.shape), name))
                weights.append(arr)
                if len(weights) >= n_target:
                    break
    else:  # 'vgg19'
        # The .npy holds a pickled dict {layer_name: [W, b]}; NumPy >= 1.16.3
        # refuses to unpickle it without allow_pickle=True.
        # SECURITY: unpickling can execute code -- only load trusted files.
        npz = np.load(weights_path, allow_pickle=True, encoding='latin1').item()
        for name in sorted(npz):
            layer_params = npz[name]
            logging.info("  Loading %s in %s" % (str(layer_params[0].shape), name))
            logging.info("  Loading %s in %s" % (str(layer_params[1].shape), name))
            weights.extend(layer_params)
            # ``>=``: each layer adds two arrays, so ``==`` could be skipped over.
            if len(weights) >= n_target:
                break
    # assign weight values; never pass more arrays than the model holds
    assign_weights(weights[:n_target], model)
    del weights
コード例 #3
0
ファイル: resnet.py プロジェクト: zuzi-rl/tensorlayer
def restore_params(network, path='models'):
    """Download (if needed) and load Keras ResNet-50 weights into ``network``.

    For every layer of ``network`` that owns weights, the datasets stored under
    the HDF5 group with the same name as the layer are read (in the order h5py
    lists them) and assigned to that layer.

    Parameters
    ----------
    network : TensorLayer model whose layers are iterated via ``all_layers``.
    path : directory the weights file is downloaded to and read from.

    Raises
    ------
    ImportError
        If ``h5py`` cannot be imported.
    """
    logging.info("Restore pre-trained parameters")
    weights_file = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
    maybe_download_and_extract(
        weights_file,
        path,
        'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/',
    )  # ls -al
    try:
        import h5py
    except ImportError as err:
        # Chain the cause so the real import failure stays visible.
        raise ImportError('h5py not imported') from err

    # ``with`` guarantees the file is closed even if an assignment raises.
    with h5py.File(os.path.join(path, weights_file), 'r') as f:
        for layer in network.all_layers:
            if len(layer.all_weights) == 0:
                continue
            group = f[layer.name]
            params = [group[n][:] for n in group]
            assign_weights(params, layer)
コード例 #4
0
def restore_model(model, layer_type):
    """Download (if needed) and load pre-trained VGG weights into ``model``.

    Parameters
    ----------
    model : TensorLayer model; at most ``len(model.all_weights)`` arrays are assigned.
    layer_type : str
        ``'vgg16'`` or ``'vgg19'``; also the key into the module-level
        ``model_saved_name`` / ``model_urls`` lookups.

    Raises
    ------
    TypeError
        If ``layer_type`` is not supported. This is checked up front; previously
        the dict lookups for the download failed first, so this error was unreachable.
    """
    logging.info("Restore pre-trained weights")
    if layer_type not in ('vgg16', 'vgg19'):
        raise TypeError(f'layer type not supported for restore_model(): {layer_type}')

    # download weights
    maybe_download_and_extract(model_saved_name[layer_type], 'models', model_urls[layer_type])
    weights_path = os.path.join('models', model_saved_name[layer_type])
    n_target = len(model.all_weights)
    weights = []
    # SECURITY: allow_pickle=True lets np.load unpickle data, which can execute
    # code -- only load files from a trusted source.
    if layer_type == 'vgg16':
        # ``with`` closes the zip archive; arrays are read lazily, one key at a time.
        with np.load(weights_path, allow_pickle=True) as npz:
            for name in sorted(npz.files):
                arr = npz[name]
                logging.info("  Loading weights %s in %s" % (str(arr.shape), name))
                weights.append(arr)
                if len(weights) >= n_target:
                    break
    else:  # 'vgg19': a pickled dict {layer_name: [W, b]}
        npz = np.load(weights_path, allow_pickle=True, encoding='latin1').item()
        for name in sorted(npz):
            layer_params = npz[name]
            logging.info("  Loading %s in %s" % (str(layer_params[0].shape), name))
            logging.info("  Loading %s in %s" % (str(layer_params[1].shape), name))
            weights.extend(layer_params)
            # ``>=``: each layer adds two arrays, so ``==`` could be skipped over.
            if len(weights) >= n_target:
                break
    # assign weight values; extra arrays beyond the model's weights are dropped
    assign_weights(weights[:n_target], model)
    del weights
コード例 #5
0
ファイル: squeezenetv1.py プロジェクト: schneicw/Chatbot
def restore_params(network, path='models'):
    """Fetch ``squeezenet.npz`` into ``path`` if needed and load it into ``network``.

    Only the leading ``len(network.all_weights)`` arrays of the archive are
    assigned; any extra arrays in the file are ignored.
    """
    logging.info("Restore pre-trained parameters")
    archive = 'squeezenet.npz'
    maybe_download_and_extract(
        archive,
        path,
        'https://github.com/tensorlayer/pretrained-models/raw/master/models/',
        expected_bytes=7405613,
    )
    params = load_npz(name=os.path.join(path, archive))
    n_weights = len(network.all_weights)
    assign_weights(params[:n_weights], network)
    del params
コード例 #6
0
ファイル: squeezenetv1.py プロジェクト: zsdonghao/tensorlayer
def restore_params(network, path='models'):
    """Load SqueezeNet pre-trained parameters into ``network``.

    ``squeezenet.npz`` is obtained via ``maybe_download_and_extract`` (called
    with ``expected_bytes=7405613``) and read from ``path``. Its first
    ``len(network.weights)`` arrays are assigned to the network.
    """
    logging.info("Restore pre-trained parameters")
    source_url = 'https://github.com/tensorlayer/pretrained-models/raw/master/models/'
    maybe_download_and_extract('squeezenet.npz', path, source_url, expected_bytes=7405613)
    npz_path = os.path.join(path, 'squeezenet.npz')
    params = load_npz(name=npz_path)
    needed = params[:len(network.weights)]
    assign_weights(needed, network)
    del params
コード例 #7
0
 def restore_params(self, sess, path='models'):
     """Download ``mobilenet.npz`` if needed and load it into ``self.net``.

     Parameters
     ----------
     sess : session passed through as the first argument of ``assign_weights``.
     path : directory used for the download and for reading the archive.
     """
     logging.info("Restore pre-trained parameters")
     archive = 'mobilenet.npz'
     maybe_download_and_extract(
         archive,
         path,
         'https://github.com/tensorlayer/pretrained-models/raw/master/models/',
         expected_bytes=25600116,
     )
     params = load_npz(name=os.path.join(path, archive))
     net = self.net
     # Only as many arrays as the network has parameters are assigned.
     assign_weights(sess, params[:len(net.all_params)], net)
     del params
コード例 #8
0
def restore_weight(net, M, layername):
    all_weights = net.all_weights
    for i in range(len(all_weights)-1):
        weights = all_weights[i]
        weights1 = all_weights[i+1]
        logging.debug(weights.name)
        if (layername in weights.name) and (layername not in weights1.name):
            break
    logging.debug(i)
    assign_weights(all_weights[0:i+1], M)
コード例 #9
0
ファイル: mobilenetv1.py プロジェクト: zhuxb/tensorlayer
def restore_params(network, path='models'):
    """Download ``mobilenet.npz`` if needed and load it into ``network``.

    Arrays at positions whose network weight name contains ``'batchnorm'``
    are reshaped to ``(1, 1, 1, -1)`` first. Only the first
    ``len(network.all_weights)`` arrays are assigned.
    """
    logging.info("Restore pre-trained parameters")
    archive = 'mobilenet.npz'
    maybe_download_and_extract(
        archive, path, 'https://github.com/tensorlayer/pretrained-models/raw/master/models/', expected_bytes=25600116
    )
    params = load_npz(name=os.path.join(path, archive))
    weights = network.all_weights
    # Batch-norm arrays are stored flat in the archive; lift them to rank 4.
    for pos in [p for p, w in enumerate(weights) if 'batchnorm' in w.name]:
        params[pos] = params[pos].reshape(1, 1, 1, -1)
    assign_weights(params[:len(weights)], network)
    del params
コード例 #10
0
ファイル: vgg16.py プロジェクト: zsdonghao/tensorlayer
 def restore_weights(self):
     """Download the VGG16 ``.npz`` weights if needed and load them into ``self``.

     Arrays are taken in sorted key order until ``len(self.weights)`` of them
     have been collected, then assigned with ``assign_weights``.
     """
     logging.info("Restore pre-trained weights")
     ## download weights
     maybe_download_and_extract(
         'vgg16_weights.npz', 'models', 'http://www.cs.toronto.edu/~frossard/vgg16/', expected_bytes=553436134
     )
     n_target = len(self.weights)
     weights = []
     # ``with`` closes the zip archive. Keys are sorted and arrays read lazily,
     # so arrays past the break point are never loaded from the large file.
     with np.load(os.path.join('models', 'vgg16_weights.npz')) as npz:
         for key in sorted(npz.files):
             arr = npz[key]
             logging.info("  Loading weights %s in %s" % (str(arr.shape), key))
             weights.append(arr)
             if len(weights) >= n_target:
                 break
     ## assign weight values
     assign_weights(weights[:n_target], self)
     del weights
コード例 #11
0
ファイル: vgg16.py プロジェクト: tqb4342/tensorlayer2
 def restore_weights(self):
     """Download the VGG16 ``.npz`` weights if needed and load them into ``self``.

     Arrays are taken in sorted key order until ``len(self.weights)`` of them
     have been collected, then assigned with ``assign_weights``.
     """
     logging.info("Restore pre-trained weights")
     ## download weights
     maybe_download_and_extract(
         'vgg16_weights.npz',
         'models',
         'http://www.cs.toronto.edu/~frossard/vgg16/',
         expected_bytes=553436134)
     n_target = len(self.weights)
     weights = []
     # ``with`` closes the zip archive. Keys are sorted and arrays read lazily,
     # so arrays past the break point are never loaded from the large file.
     with np.load(os.path.join('models', 'vgg16_weights.npz')) as npz:
         for key in sorted(npz.files):
             arr = npz[key]
             logging.info("  Loading weights %s in %s" %
                          (str(arr.shape), key))
             weights.append(arr)
             if len(weights) >= n_target:
                 break
     ## assign weight values
     assign_weights(weights[:n_target], self)
     del weights
コード例 #12
0
ファイル: vgg19.py プロジェクト: tqb4342/tensorlayer2
    def restore_weights(self, sess=None):
        """Download the VGG19 ``.npy`` weights if needed and load them into ``self``.

        The file holds a pickled dict ``{layer_name: [W, b]}``. Layers are read
        in sorted name order and their ``W``/``b`` appended until
        ``len(self.all_params)`` arrays have been collected.

        Parameters
        ----------
        sess : passed unchanged as the first argument of ``assign_weights``.
        """
        logging.info("Restore pre-trained weights")
        ## download weights
        maybe_download_and_extract(
            'vgg19.npy', 'models',
            'https://media.githubusercontent.com/media/tensorlayer/pretrained-models/master/models/',
            expected_bytes=574670860
        )
        vgg19_npy_path = os.path.join('models', 'vgg19.npy')
        # NumPy >= 1.16.3 refuses to unpickle object arrays without allow_pickle=True.
        # SECURITY: unpickling can execute code -- only load this file from a trusted source.
        npz = np.load(vgg19_npy_path, allow_pickle=True, encoding='latin1').item()

        # NOTE(review): this reads ``self.all_params`` while sibling loaders use
        # ``self.weights`` -- confirm which attribute this class exposes.
        n_target = len(self.all_params)
        weights = []
        for name in sorted(npz):
            W = np.asarray(npz[name][0])
            b = np.asarray(npz[name][1])
            logging.info("  Loading %s: %s, %s" % (name, W.shape, b.shape))
            weights.extend([W, b])
            # ``>=``: two arrays are added per layer, so ``==`` could be skipped over.
            if len(weights) >= n_target:
                break
        ## assign weight values
        assign_weights(sess, weights[:n_target], self)
        del weights
コード例 #13
0
    def find_top_model(self, sort=None, model_name='model', **kwargs):
        """Find a model in the database that matches the requirement and rebuild it.

        Parameters
        ----------
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        model_name : str or None
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number, etc. (optional).

        Examples
        ---------
        - see ``save_model``.

        Returns
        ---------
        network : TensorLayer Model or False
            The network rebuilt from the stored architecture, with the stored
            parameters assigned to it. ``False`` is returned when no document
            matches, or when rebuilding/loading fails (the error is logged).
            Note: the other fields of the matching document (e.g. accuracy) are
            currently NOT attached to the returned network.
        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)

        s = time.time()

        d = self.db.Model.find_one(filter=kwargs, sort=sort)
        if d is None:
            print("[Database] FAIL! Cannot find model: {}".format(kwargs))
            return False

        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']

        try:
            params = self._deserialization(self.model_fs.get(params_id).read())
            network = static_graph2net(graphs)
            assign_weights(weights=params, network=network)

            pc = self.db.Model.find(kwargs)
            print(
                "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
                    kwargs, sort, _datetime, round(time.time() - s, 2)
                )
            )

            # NOTE: copying the document fields onto the network
            # (network.__dict__.update({"_%s" % key: d[key]})) is disabled.

            # Warn when more than one stored parameter set matches the query.
            params_id_list = pc.distinct('params_id')
            n_params = len(params_id_list)
            if n_params != 1:
                print("     Note that there are {} models match the kwargs".format(n_params))
            return network
        except Exception as e:
            # Best-effort lookup: log the failure and report it as False.
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logging.info("{}  {}  {}  {}  {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))
            return False