コード例 #1
0
    def on_epoch_end(self, epoch, logs=None):
        """Export dilation factors when the monitored validation metric improves.

        Reads the monitored metric for ``cf.dataset`` from ``logs``. When it is
        lower than ``self.best``, it updates ``self.best`` and processes every
        ``learned_conv2d ... /gamma`` weight of ``self.model``. Each such weight
        is binarized with ``cf.threshold`` and converted to dilation factors via
        ``utils.dil_fact(..., op='mul')``. The results are stored in
        ``self.gamma`` and written with ``utils.save_dil_fact`` to
        ``cf.saving_path``.

        Args:
            epoch: Index of the epoch that just ended (not used here).
            logs: Metric dict passed by the training loop; may be None.

        Raises:
            ValueError: If ``cf.dataset`` has no monitored metric defined here.
        """
        # Metric to monitor (lower is better) for each supported dataset.
        monitor_keys = {
            'PPG_Dalia': "val_mean_absolute_error",
            'copy_memory': "val_loss",
            'poly_music': "val_loss",
        }
        if cf.dataset not in monitor_keys:
            # Previously this fell through with `current` unbound and failed
            # with an UnboundLocalError at the comparison below.
            raise ValueError("{} is not supported".format(cf.dataset))

        current = (logs or {}).get(monitor_keys[cf.dataset])
        if current is None:
            # Metric absent from the logs: nothing to compare against, so skip
            # this epoch instead of crashing inside np.less.
            return

        # Compare with the previous best one (lower is better).
        if np.less(current, self.best):
            self.best = current

            # Record the gammas of the best model. Use the model this callback
            # is attached to (self.model) rather than a module-level global.
            names = [
                weight.name for layer in self.model.layers
                for weight in layer.weights
            ]
            weights = self.model.get_weights()
            for name, weight in zip(names, weights):
                if re.search('learned_conv2d.+_?[0-9]/gamma', name):
                    # Binarize the gamma, then turn the mask into dilations.
                    mask = np.array(weight > cf.threshold, dtype=bool)
                    self.gamma[name] = utils.dil_fact(mask, op='mul')

            print("New best MAE, update file. \n")
            print(self.gamma)
            utils.save_dil_fact(cf.saving_path, self.gamma)
コード例 #2
0
def _collect_learned_gammas(model):
    """Return ``{weight name: value}`` for each learned_conv2d gamma of *model*."""
    names = [
        weight.name for layer in model.layers for weight in layer.weights
    ]
    return {
        name: value
        for name, value in zip(names, model.get_weights())
        if re.search('learned_conv2d.+_?[0-9]/gamma', name)
    }


def train_gammas_Char_PTB(model, epochs, n_words, X_train, X_valid, X_test,
                          patience=30):
    """Train on Char-PTB and export dilation factors whenever validation improves.

    Each epoch it does the following:
      1. Calls ``train_Char_PTB``.
      2. Refreshes ``gamma`` with the current raw ``learned_conv2d`` gamma
         weights of ``model`` and prints it.
      3. If the validation loss improved, binarizes each gamma with
         ``cf.threshold``, converts it with ``utils.dil_fact(..., op='mul')``
         and saves the result with ``utils.save_dil_fact`` to
         ``cf.saving_path``.
    Training stops early after ``patience`` consecutive epochs without
    improvement.

    Args:
        model: Model being trained; its layers' weights are inspected.
        epochs: Maximum number of epochs to run.
        n_words, X_train, X_valid, X_test: Forwarded to ``train_Char_PTB``.
        patience: Epochs without improvement before stopping (default 30).
    """
    train_losses = list()
    valid_losses = list()
    test_losses = list()

    # np.Inf was removed in NumPy 2.0; np.inf works on every version.
    best_v_l = np.inf
    wait = 0
    gamma = dict()

    for ep in range(epochs):
        tr_l, v_l, t_l = train_Char_PTB(model, ep, n_words, X_train, X_valid,
                                        X_test, train_losses, valid_losses,
                                        test_losses)

        # Refresh the raw gamma values from the current weights.
        raw_gammas = _collect_learned_gammas(model)
        gamma.update(raw_gammas)
        print(gamma)

        # early_stop + export_structure
        if np.less(v_l, best_v_l):
            # NOTE(review): best_v_l only moves when v_l >= 2, so a loss below 2
            # never becomes the new reference and later epochs are compared
            # against the last best >= 2. Intent not visible here; confirm.
            if v_l >= 2:
                best_v_l = v_l
            wait = 0

            # Record the best model: binarize each gamma and convert the mask
            # into dilation factors.
            for name, weight in raw_gammas.items():
                mask = np.array(weight > cf.threshold, dtype=bool)
                gamma[name] = utils.dil_fact(mask, op='mul')

            print("New best MAE, update file. \n")
            print(gamma)
            utils.save_dil_fact(cf.saving_path, gamma)

        else:
            wait += 1
            print("Val loss did not improve from {}".format(best_v_l))
            print("Iter for at least {} epochs".format(patience - wait))

        if wait >= patience:
            print("Early Stop")
            break
コード例 #3
0
    def on_epoch_end(self, epoch, logs=None):
        """Export dilation factors whenever the monitored metric improves.

        Selects the metric, the optimization direction and a warm-up length
        from ``cf.dataset``. While ``self.i <= wait`` the epoch is only counted
        (``self.i`` is incremented) and nothing is compared. After that, if the
        metric beats ``self.best``, it updates ``self.best`` and processes every
        ``learned_conv2d``/``weight_norm`` gamma weight of ``self.model``. Each
        such weight is binarized with ``cf.threshold`` and converted with
        ``utils.dil_fact(..., op='mul')``. The results are stored in
        ``self.gamma`` and saved with ``utils.save_dil_fact``.

        Args:
            epoch: Index of the epoch that just ended (not used here).
            logs: Metric dict passed by the training loop; may be None.

        Exits the process (``sys.exit``) for an unsupported ``cf.dataset``.
        """
        logs = logs or {}

        # Pick the monitored metric, whether higher is better, and the number
        # of warm-up epochs to skip before comparing.
        if cf.dataset == 'PPG_Dalia':
            # NOTE(review): `val_mae` is a name defined outside this view
            # (presumably the metric key string); confirm it exists.
            current = logs.get(val_mae)
            maximize = False
            wait = 0
        elif cf.dataset == 'Nottingham' or cf.dataset == 'JSB_Chorales':
            current = logs.get("val_loss")
            maximize = False
            # Previously unset in this branch, which raised UnboundLocalError.
            wait = 0
        elif cf.dataset == 'SeqMNIST' or cf.dataset == 'PerMNIST':
            current = logs.get("val_accuracy")
            maximize = True
            wait = 20
        else:
            print("{} is not supported".format(cf.dataset))
            sys.exit()

        if self.i <= wait:
            # Still warming up: count this epoch and skip the comparison.
            self.i += 1
            return

        if current is None:
            # Metric absent from the logs: nothing to compare against.
            return

        # Compare with the previous best one.
        if maximize:
            improved = bool(current > self.best)
        else:
            improved = bool(np.less(current, self.best))
        if not improved:
            return

        self.best = current

        # Record the best model: binarize each gamma and turn it into
        # dilation factors. A single alternation matches exactly the names
        # either of the two original patterns matched.
        names = [
            weight.name for layer in self.model.layers
            for weight in layer.weights
        ]
        weights = self.model.get_weights()
        for name, weight in zip(names, weights):
            if re.search('(?:learned_conv2d|weight_norm).+_?[0-9]/gamma',
                         name):
                mask = np.array(weight > cf.threshold, dtype=bool)
                self.gamma[name] = utils.dil_fact(mask, op='mul')
        print("New best model, update file. \n")
        print(self.gamma)
        utils.save_dil_fact(cf.saving_path, self.gamma)
コード例 #4
0
                # Record the best model if current results is better (less).
                names = [
                    weight.name for layer in model.layers
                    for weight in layer.weights
                ]
                weights = model.get_weights()
                for name, weight in zip(names, weights):
                    if re.search('learned_conv2d.+_?[0-9]/gamma', name):
                        gamma[name] = weight
                        gamma[name] = np.array(gamma[name] > cf.threshold,
                                               dtype=bool)
                        gamma[name] = utils.dil_fact(gamma[name], op='mul')

                print("New best MAE, update file. \n")
                print(gamma)
                utils.save_dil_fact(cf.saving_path, gamma)

            else:
                wait += 1
                print("Val loss did not improve from {}".format(best_v_l))
                print("Iter for at least {} epochs".format(patience - wait))

        if wait >= patience:
            print("Early Stop")
            break

        ep += 1

# Save the trained weights; the file name embeds the cf.warmup value.
model.save_weights(''.join((
    cf.saving_path,
    'autodil/test2_trained_weights_warmup',
    str(cf.warmup),
    '.h5',
)))