Exemple #1
0
    def restore_params(self, sess):
        """Fetch the VGG19 ``.npy`` weights and assign them to ``self.net``.

        The file is obtained through ``maybe_download_and_extract`` into the
        ``models`` directory.  Every entry of the stored dict contributes its
        first two elements (weights, then biases), visited in sorted key
        order, and the flat list is handed to ``assign_params``.

        Args:
            sess: Session object forwarded unchanged to ``assign_params``.
        """
        # NOTE: an earlier, commented-out variant of this method skipped the
        # download and asked the user to fetch vgg19.npy manually from
        # https://github.com/machrisaa/tensorflow-vgg when it was missing.
        logging.info("Restore pre-trained parameters")
        maybe_download_and_extract(
            'vgg19.npy',
            'models',
            'https://media.githubusercontent.com/media/tensorlayer/pretrained-models/master/models/',
            expected_bytes=574670860)
        vgg19_npy_path = os.path.join('models', 'vgg19.npy')

        # The file holds a pickled dict wrapped in a 0-d object array (hence
        # ``.item()``).  NumPy >= 1.16.3 refuses to load object arrays unless
        # allow_pickle=True is passed explicitly.
        # SECURITY: unpickling can execute arbitrary code; this is only
        # acceptable because the file comes from the fixed URL above.
        npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

        params = []
        for name, layer in sorted(npz.items()):
            W = np.asarray(layer[0])
            b = np.asarray(layer[1])
            print("  Loading %s: %s, %s" % (name, W.shape, b.shape))
            params.extend([W, b])

        print("Restoring model from npz file")
        assign_params(sess, params, self.net)
Exemple #2
0
 def restore_params(self, sess, path='models'):
     """Download the MobileNet ``.npz`` weights if needed and load them into ``self.net``.

     Args:
         sess: Session object forwarded unchanged to ``assign_params``.
         path: Directory used both as the download target and as the
             location the archive is read from.
     """
     logging.info("Restore pre-trained parameters")

     archive_name = 'mobilenet.npz'
     # The original carried an "ls -al" remark next to expected_bytes --
     # presumably that is where the byte count was read from.
     maybe_download_and_extract(
         archive_name,
         path,
         'https://github.com/tensorlayer/pretrained-models/raw/master/models/',
         expected_bytes=25600116,
     )

     loaded = load_npz(name=os.path.join(path, archive_name))
     # Only pass as many arrays as the network has parameters.
     wanted = len(self.net.all_params)
     assign_params(sess, loaded[:wanted], self.net)
     del loaded
Exemple #3
0
    def restore_params(self, sess, weights):
        """Load pre-trained parameters from an ``.npz`` archive into ``self.vgg``.

        Arrays are read in sorted key order; collection stops as soon as the
        number of arrays equals ``len(self.vgg.all_params)``.  The collected
        list is then handed to ``assign_params``.

        Args:
            sess: Session object forwarded unchanged to ``assign_params``.
            weights: Location of the ``.npz`` archive, passed to ``np.load``.
        """
        from tensorlayer.files import assign_params
        logging.info("Restore pre-trained parameters")

        wanted = len(self.vgg.all_params)
        params = []
        # Open the archive as a context manager so its file handle is closed,
        # and index it by key so only the arrays actually needed are
        # decompressed (sorting ``npz.items()`` would load every array first).
        with np.load(weights) as npz:
            for key in sorted(npz.files):
                array = npz[key]
                print("  Loading params %s" % str(array.shape))
                params.append(array)
                if len(params) == wanted:
                    break

        assign_params(sess, params, self.vgg)
Exemple #4
0
    def restore_params(self, md, sess):
        """Load VGG16 weights from ``../data/pretrain_model`` into ``md.net``.

        Arrays are read from ``vgg16_weights.npz`` in sorted key order;
        collection stops as soon as the number of arrays equals
        ``len(md.all_params)``.  The collected list is then handed to
        ``assign_params``.

        Args:
            md: Model object; its ``all_params`` gives the number of arrays to
                load and its ``net`` is the assignment target.
            sess: Session object forwarded unchanged to ``assign_params``.
        """
        # we have to customize the restore param function a bit
        from tensorlayer.files import assign_params
        logging.info("Restore pre-trained parameters")

        weights_path = os.path.join('../data/pretrain_model', 'vgg16_weights.npz')
        # NOTE(review): the count comes from md.all_params while the arrays are
        # assigned to md.net -- confirm both describe the same parameter list.
        wanted = len(md.all_params)
        params = []
        # Open the archive as a context manager so its file handle is closed,
        # and index it by key so only the arrays actually needed are
        # decompressed (sorting ``npz.items()`` would load every array first).
        with np.load(weights_path) as npz:
            for key in sorted(npz.files):
                array = npz[key]
                print("  Loading params %s" % str(array.shape))
                params.append(array)
                if len(params) == wanted:
                    break

        assign_params(sess, params, md.net)
Exemple #5
0
    def restore_params(self, sess):
        """Download the VGG16 ``.npz`` weights if needed and assign them to ``self.net``.

        The archive is obtained through ``maybe_download_and_extract`` into the
        ``models`` directory.  Arrays are read in sorted key order; collection
        stops as soon as the number of arrays equals ``len(self.all_params)``.
        The collected list is then handed to ``assign_params``.

        Args:
            sess: Session object forwarded unchanged to ``assign_params``.
        """
        logging.info("Restore pre-trained parameters")
        maybe_download_and_extract(
            'vgg16_weights.npz', 'models', 'http://www.cs.toronto.edu/~frossard/vgg16/', expected_bytes=553436134
        )

        # NOTE(review): the count comes from self.all_params while the arrays
        # are assigned to self.net -- confirm both describe the same list.
        wanted = len(self.all_params)
        params = []
        # Open the archive as a context manager so its file handle is closed,
        # and index it by key so only the arrays actually needed are
        # decompressed (sorting ``npz.items()`` would load every array first).
        with np.load(os.path.join('models', 'vgg16_weights.npz')) as npz:
            for key in sorted(npz.files):
                array = npz[key]
                # Lazy %-args: same message text, formatted only if emitted.
                logging.info("  Loading params %s", array.shape)
                params.append(array)
                if len(params) == wanted:
                    break

        assign_params(sess, params, self.net)
Exemple #6
0
    sess.run(tf.global_variables_initializer())
    maybe_download_and_extract('vgg16_weights.npz',
                               'models',
                               'http://www.cs.toronto.edu/~frossard/vgg16/',
                               expected_bytes=553436134)
    npz = np.load(os.path.join('models', 'vgg16_weights.npz'))
    params = []
    idx = 0
    for val in sorted(npz.items()):
        idx = idx + 1
        print("  Loading params %s" % (str(val[1].shape)))
        params.append(val[1])
        if idx == 17:
            break
    assign_params(sess, params, net_work['block4'])
    for epoch in range(train_epoch):
        idx = 0
        for batch_imgs, batch_ann in tl.iterate.minibatches(laod_imgs,
                                                            load_ann,
                                                            batch_size,
                                                            shuffle=False):
            g_img, g_ann = getTrainData(batch_imgs, batch_ann)
            gt_class, gt_location, gt_positives, gt_negatives = generate_groundtruth_data(
                g_ann)
            # print("g_img",g_img)
            # print("gt_positives",np.sum(gt_positives,axis=1))
            # print("gt_positives", np.sum(gt_negatives, axis=1))
            # print("pre_cls", sess.run(pre_cls, feed_dict={imageinput: g_img, groundtruth_class: gt_class,
            #                                               groundtruth_location: gt_location,
            #                                               groundtruth_positives: gt_positives,