Beispiel #1
0
def init_weights(weights):
    """Initialize the given variables.

    Variables whose initializer has an input whose name contains ``"_cai_"``
    are collected separately. Their values come from
    ``nn.initializers.ca.generate_batch`` and are assigned with
    ``nn.batch_set_value``. Every other variable's initializer op is run in a
    single ``nn.tf_sess.run`` call.

    Args:
        weights: iterable of variables exposing ``initializer`` (with
            ``inputs``), ``shape`` and ``dtype``. Presumably these are
            TensorFlow variables; confirm with callers.

    Returns:
        None.
    """
    plain_init_ops = []

    ca_weights = []
    ca_specs = []
    for w in weights:
        initializer = w.initializer
        for init_input in initializer.inputs:
            if "_cai_" in init_input.name:
                ca_weights.append(w)
                # generate_batch receives one (shape list, numpy dtype) spec per weight.
                ca_specs.append((w.shape.as_list(), w.dtype.as_numpy_dtype))
                break
        else:
            # for/else: reached only when no "_cai_" input was found above.
            plain_init_ops.append(initializer)

    if plain_init_ops:
        nn.tf_sess.run(plain_init_ops)

    if ca_specs:
        generated = nn.initializers.ca.generate_batch(ca_specs)
        nn.batch_set_value(list(zip(ca_weights, generated)))
Beispiel #2
0
    def set_weights(self, new_weights):
        """Assign ``new_weights`` to this object's weights, position by position.

        Each new value whose shape differs from its target weight's shape is
        reshaped to that shape. All (weight, value) pairs are then passed
        together to ``nn.batch_set_value``.

        Args:
            new_weights: sequence of arrays exposing ``shape`` and
                ``reshape``, one per entry of ``self.get_weights()`` and in
                the same order.

        Raises:
            ValueError: if ``new_weights`` and ``self.get_weights()`` differ
                in length.
        """
        weights = self.get_weights()
        if len(weights) != len(new_weights):
            raise ValueError('len of lists mismatch')

        tuples = []
        for w, new_w in zip(weights, new_weights):
            target_shape = tuple(w.shape.as_list())
            # Compare the full shapes; reshape only when they actually differ.
            if new_w.shape != target_shape:
                new_w = new_w.reshape(target_shape)

            tuples.append((w, new_w))

        nn.batch_set_value(tuples)
Beispiel #3
0
    def load_weights(self, filename):
        """Load this object's weights from the pickled mapping stored in ``filename``.

        The file must unpickle to a dict-like object. It maps each weight's
        name, with the leading ``"<self.name>/"`` scope removed, to a value.
        Each found value is reshaped to its weight's shape. Weights with no
        entry (or a ``None`` entry) are paired with their own ``initializer``
        instead. All pairs are passed to ``nn.batch_set_value``.

        Args:
            filename: path (str or path-like) of the pickled weights file.

        Returns:
            True if the weights were assigned. False if the file does not
            exist, or if assigning the weights fails with an ordinary
            exception, e.g. a weight outside ``self.name``'s scope or a value
            that cannot be reshaped.

        Raises:
            Exception: if ``self.name`` is None. This is checked only after
                the file has been found and read.
            Errors raised while reading or unpickling the file propagate
            unchanged.
        """
        filepath = Path(filename)
        if not filepath.exists():
            return False

        # SECURITY: pickle.loads can execute arbitrary code; only load files
        # from trusted sources.
        d = pickle.loads(filepath.read_bytes())

        weights = self.get_weights()

        if self.name is None:
            raise Exception("name must be defined.")

        try:
            tuples = []
            for w in weights:
                w_name_split = w.name.split('/')
                if self.name != w_name_split[0]:
                    raise Exception("weight first name != Saveable.name")

                sub_w_name = "/".join(w_name_split[1:])
                w_val = d.get(sub_w_name)

                if w_val is None:
                    # No stored value: hand the weight's own initializer to
                    # batch_set_value instead of a concrete value.
                    tuples.append((w, w.initializer))
                else:
                    tuples.append((w, np.reshape(w_val, w.shape.as_list())))

            nn.batch_set_value(tuples)
        except Exception:
            # Best-effort load: ordinary failures are reported as False, while
            # KeyboardInterrupt / SystemExit still propagate.
            return False

        return True