コード例 #1
0
    def paintEvent(self, e):
        """Paint the combo box frame, then draw a sample line across the
        focus rect in the currently selected pen style and the row's colour.

        Args:
            e: the QPaintEvent (unused; the whole widget is repainted).
        """
        # Map the selected row to a Qt pen style via the model's lookup table.
        style = NWM.intToStyle(NWM.LINE_STYLE_LIST[self.currentIndex()])
        # The colour lives in the NET_COLOR column of the same row;
        # sibling() fetches it without reaching into the model's createIndex().
        line_color = self.network_index.sibling(self.network_index.row(),
                                                NWM.NET_COLOR).data()

        p = QStylePainter(self)
        p.setPen(self.palette().color(QPalette.Text))

        opt = QStyleOptionComboBox()
        self.initStyleOption(opt)
        p.drawComplexControl(QStyle.CC_ComboBox, opt)

        rect = p.style().subElementRect(QStyle.SE_ComboBoxFocusRect, opt, self)
        rect.adjust(+5, 0, -5, 0)

        # Reuse the existing painter rather than opening a second QPainter on
        # the same device (two concurrently active painters is unsupported).
        p.save()
        pen = QPen()
        pen.setColor(line_color)
        pen.setWidth(3)
        pen.setStyle(style)
        p.setPen(pen)

        # Integer division: QPainter.drawLine(int, int, int, int) rejects the
        # float that '/' would produce under Python 3.
        middle = (rect.bottom() + rect.top()) // 2
        p.drawLine(rect.left(), middle, rect.right(), middle)
        p.restore()
コード例 #2
0
    def paint(self, painter, option, index):
        """Render the cell as a horizontal sample line in the row's
        pen style and network colour.

        Args:
            painter: active QPainter supplied by the view.
            option:  QStyleOptionViewItem carrying the cell geometry.
            index:   model index whose data() holds the pen-style int.
        """
        style = NWM.intToStyle(int(index.data()))
        # Colour comes from the NET_COLOR column of the delegate's row.
        line_color = self.network_index.sibling(self.network_index.row(),
                                                NWM.NET_COLOR).data()
        painter.save()

        # adjusted() returns a copy; the original adjust() mutated
        # option.rect in place, leaking the inset into the view's
        # subsequent painting of the same cell.
        rect = option.rect.adjusted(5, 0, -5, 0)

        pen = QPen()
        pen.setColor(line_color)
        pen.setWidth(3)
        pen.setStyle(style)
        painter.setPen(pen)

        # Integer division: drawLine() requires int coordinates in PyQt5.
        middle = (rect.bottom() + rect.top()) // 2
        painter.drawLine(rect.left(), middle, rect.right(), middle)
        painter.restore()
コード例 #3
0
import sys

from digits import DigitData
from model import NetworkModel

if __name__ == '__main__':
    # CLI: evaluate.py <model-file> <test-data.json>
    model_path, data_path = sys.argv[1], sys.argv[2]
    network = NetworkModel.load(model_path)
    dataset = DigitData.from_json(data_path)
    print('Accuracy: %s' % network.evaluate(dataset, progress=True))
コード例 #4
0
import activation
from digits import DigitData
from model import NetworkModel

EPOCHS = 60

def train(model, train_data, test_data, model_file, epochs=EPOCHS, limit=10000):
    """Train *model* for up to *epochs* passes, evaluating after each one.

    Args:
        model: network to train (mutated in place).
        train_data: dataset with .shuffle() and a .data list of examples.
        test_data: dataset passed to model.evaluate() after each epoch.
        model_file: path the model is saved to — even on Ctrl-C.
        epochs: number of passes over the shuffled training data.
        limit: cap on examples used per epoch (was a hard-coded 10000).
    """
    # tqdm was used but never imported in the original snippet; fall back to
    # a pass-through so training still runs without the dependency.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable):
            return iterable

    try:
        for epoch in range(epochs):
            print('Epoch %s' % epoch)
            train_data.shuffle()
            for datum in tqdm(train_data.data[:limit]):
                model.train(datum.features(), datum.label_vec)
            print('Evaluating...')
            print('Accuracy: %s' % model.evaluate(test_data, progress=True))
    except KeyboardInterrupt:
        # Partial training is still worth keeping; fall through to save.
        pass
    finally:
        # Persist weights whether we finished, crashed, or were interrupted.
        model.save(model_file)


# python train.py data/train.json data/weights.json data/model.json
if __name__ == '__main__':
    import sys  # missing in the original snippet; argv is read below

    train_data = DigitData.from_json(sys.argv[1])
    test_data = DigitData.from_json(sys.argv[2])
    model_file = sys.argv[3]
    # One hidden layer of 30 units; sigmoid activations on both layers.
    model = NetworkModel(train_data.num_features(), 30, train_data.num_labels(),
                         activation_fns=[activation.Sigmoid, activation.Sigmoid])
    train(model, train_data, test_data, model_file)
コード例 #5
0
File: main.py  Project: aki85/osero-ai
# DQN training setup for the Othello (osero) environment using keras-rl.
from osero import Osero

# NOTE(review): the env below is created and the script then exits
# immediately — this looks like debug leftover; everything after exit()
# is dead code. Confirm whether the early exit is intentional.
env = Osero(board_size=6)
exit()
from logger import EpisodeLogger
from model import NetworkModel
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

# Recreated here with a different board size; shadows the env above.
env = Osero(board_size=4)
nb_actions = env.action_space.n

model = NetworkModel.create_simple_nn(env)
print(model.summary())

# Replay buffer: keeps the last 10000 transitions, single-frame windows.
memory = SequentialMemory(limit=10000, window_length=1)

# Epsilon-greedy exploration with a fixed 10% random-action rate.
policy = EpsGreedyQPolicy(eps=0.1)
dqn = DQNAgent(model=model,
               nb_actions=nb_actions,
               memory=memory,
               nb_steps_warmup=100,  # steps collected before learning starts
               target_model_update=1e-2,  # soft target-network update rate
               policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
# dqn.load_weights('results/190621/10000.h5')
history = dqn.fit(env,
                  nb_steps=10000,