示例#1
0
def graphic(net: Network, env: Env):
    """Run an interactive pygame view of the goalkeeper model's predictions.

    Every frame clears the window to black and draws, from ``env``: the goal
    line (white), the line between the two posts and a small dot on each post
    (green), and a green circle outline of radius ``env.penalty_area_r``
    around ``env.center_goal``. The mouse cursor is hidden. While any mouse
    button is held, the cursor position is used as the attacker:
    - it is drawn as a red dot, with light-red lines to both posts;
    - it is normalised by half the window size and passed to ``net.predict``;
    - the prediction is rescaled to pixels and drawn as a green dot.

    Returns once the window is closed.

    Args:
        net: Model queried with a ``(1, 2)`` array of normalised attacker
            coordinates.
        env: Provides the window size and the goal geometry; its points are
            mapped to screen coordinates with ``convert_cords(env)``.
    """
    window = pygame.display.set_mode((env.width, env.height))
    pygame.mouse.set_visible(False)

    while True:
        window.fill((0, 0, 0))
        pygame.draw.line(window, WHITE, env.goal_line_left.convert_cords(env),
                         env.goal_line_right.convert_cords(env))
        pygame.draw.line(window, GREEN, env.left_post.convert_cords(env),
                         env.right_post.convert_cords(env))
        pygame.draw.circle(window, GREEN, env.left_post.convert_cords(env), 2)
        pygame.draw.circle(window, GREEN, env.right_post.convert_cords(env), 2)
        # Width 1 draws only the outline of the circle.
        pygame.draw.circle(window, GREEN, env.center_goal.convert_cords(env),
                           env.penalty_area_r, 1)

        t.sleep(0.01)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                # The display is gone now: drawing on ``window`` or calling
                # pygame.display.flip() afterwards would raise pygame.error.
                return

        if any(pygame.mouse.get_pressed()):
            # Cursor position in window pixels.
            atk = pygame.mouse.get_pos()
            pygame.draw.line(window, LIGHT_RED, atk,
                             env.left_post.convert_cords(env))
            pygame.draw.line(window, LIGHT_RED, atk,
                             env.right_post.convert_cords(env))
            pygame.draw.circle(window, RED, atk, 2)

            # NOTE(review): the network input is the raw window position,
            # whereas its output is treated as env coordinates (it goes
            # through convert_cords below) -- confirm both share one frame.
            # ``[0][0]`` must yield an (x, y) pair; presumably predict
            # returns shape (1, 1, 2) -- confirm.
            gk_x, gk_y = net.predict(
                np.array(
                    [[atk[0] / (env.width / 2),
                      atk[1] / (env.height / 2)]]))[0][0]
            gk_x, gk_y = int(gk_x * (env.width // 2)), int(gk_y *
                                                           (env.height // 2))
            gk = Point(gk_x, gk_y)

            pygame.draw.circle(window, GREEN, gk.convert_cords(env), 2)

        pygame.display.flip()
示例#2
0
# Fit the network to a sine curve rescaled into [0, 1], y = (1 + sin(x)) / 2,
# then save a scatter plot of its predictions and a plot of the training loss.

import numpy as np
import matplotlib

# Select the non-interactive Agg backend before pyplot (or any module that
# may import it) is loaded, so figures are rendered straight to files without
# a display. Older matplotlib releases ignore matplotlib.use() once pyplot has
# already been imported.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402 -- must follow matplotlib.use()
from net import Network, train  # noqa: E402

iterations = 500

# 1000 samples drawn uniformly from [0, 2*pi), one feature per row.
X = np.random.rand(1000, 1) * 2 * np.pi
y = (1 + np.sin(X)) / 2

# One input column and one output column.
# NOTE(review): presumably Network(n_inputs, n_outputs) -- confirm in net.py.
net = Network(1, 1)
# NOTE(review): the meaning of train's last argument (10) is defined in
# net.py -- confirm there.
net, loss = train(net, X, y, iterations, 10)

plt.scatter(X, net.predict(X))
plt.xlabel('X')
plt.ylabel('Predicted y: (1 + sin(x)) / 2')
plt.savefig('sin2.png')

# Reuse the same figure for the loss curve; the x-axis assumes one loss value
# per iteration.
plt.clf()
plt.plot(np.arange(iterations), loss)
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Sin Approximation Loss')
plt.savefig('sinloss2.png')
import matplotlib
import numpy as np  # needed for np.arange below

# Select the non-interactive Agg backend before pyplot (or any module that
# may import it) is loaded, so figures are rendered straight to files without
# a display.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402 -- must follow matplotlib.use()

from utils import getData  # noqa: E402
from net import Network, train  # noqa: E402

# --- Flower ---------------------------------------------------------------
iterations = 1000

# NOTE(review): presumably getData(2) returns (pixel coordinates, colour
# values) for the flower image -- confirm in utils.py.
X, y = getData(2)
netF = Network(2, 3)
netF, loss = train(netF, X, y, iterations, 100)

# Reshape the predictions into a 133 x 140 x 3 image for display.
# NOTE(review): the image size is hard-coded and must match getData(2).
preds = netF.predict(X)
plt.imshow(preds.reshape(133, 140, 3))
plt.title('Flower')
plt.savefig('Flower2.png')

# Reuse the same figure for the loss curve; the x-axis assumes one loss value
# per iteration.
plt.clf()
plt.plot(np.arange(iterations), loss)
plt.ylabel('Loss')
plt.xlabel('Iterations')
plt.title('Flower Loss')
plt.savefig('Flowerloss2.png')

# --- Lincoln --------------------------------------------------------------
iterations = 1000
X, y = getData(1)
netL = Network(2, 1)
# NOTE(review): the Lincoln section stops here -- netL is never trained or
# plotted; the rest of this example appears to be missing.
示例#4
0
class App(QWidget):
    """Qt widget for interactively correcting segmentations and retraining.

    Subjects come from a ``DataLoader`` one at a time. Each is shown slice by
    slice, with its segmentation overlaid in semi-transparent blue. The user
    edits the overlay by drawing closed freehand regions with the mouse:
    a plain drag adds, Shift+drag removes.

    When a subject is saved, its image and mask are queued, and
    ``self.network`` is fitted on the segmented slices of the five most
    recently saved subjects. After at least one fit, new subjects without a
    segmentation are pre-segmented with ``self.network.predict``.

    Keys: Esc quits, S saves (and trains), N saves and moves to the next
    subject, C calls ``SubjectView.clear`` on the current subject. The mouse
    wheel changes the displayed slice.
    """

    # Drawing modes for the freehand path currently being drawn.
    DRAW_NONE = 0
    DRAW_ADD = 1
    DRAW_REMOVE = 2

    # Qt reports wheel rotation in eighths of a degree; one standard notch
    # (15 degrees) is 120 units.
    _WHEEL_STEP = 120

    def __init__(self, files):
        super().__init__()
        self.network = Network()
        self.loader = DataLoader(files)
        self.subject_view = None  # SubjectView of the subject on screen
        self.slice_index = 0      # index of the displayed slice
        self.num_trained = 0      # number of completed network.fit() calls
        self.previous = []        # (img, segm, (z_start, z_stop)) per save
        self.draw_mode = App.DRAW_NONE
        self.setGeometry(100, 100, 500, 500)
        self.path = None          # QPainterPath being drawn, or None
        self._wheel_delta = 0     # wheel rotation not yet turned into steps

        self.next()

    def next(self):
        """Save the current subject (training on it) and show the next one.

        If the loader returns no further subject, the current view is kept.
        """
        self.save()

        subject = self.loader.next()
        if not subject:
            return

        if subject.segm is None:
            if self.num_trained < 1:
                # Network not trained yet: start from an empty mask.
                subject.segm = np.zeros(subject.img.shape, dtype=np.uint8)
            else:
                # Pre-segment with the network. Clip before the uint8 cast:
                # casting out-of-range floats to uint8 is platform-dependent.
                pred = self.network.predict(subject.img)
                subject.segm = np.clip(255 * pred, 0, 255).astype(np.uint8)

        self.subject_view = SubjectView(subject)
        self.slice_index = len(self.subject_view.slices) // 2

        self.update()

    def save(self):
        """Store the current subject's mask and fit the network on recent saves.

        The mask is scanned along its first axis for the span from the first
        to the last slice with any non-zero value. If there is none, nothing
        is queued or trained. Otherwise the subject is appended to
        ``self.previous``. The network is then fitted on every slice in that
        span for the last five queued subjects, using ``mask == 255`` as the
        binary target.
        """
        if not self.subject_view:
            return

        # NOTE(review): presumably writes the edited slice masks back into
        # ``subject.segm`` (read just below) -- confirm in SubjectView.
        self.subject_view.save()

        subject = self.subject_view.subject

        # Maximum mask value per slice along the first axis.
        amax = np.max(subject.segm.reshape(subject.segm.shape[0], -1), axis=1)
        ids = np.where(amax > 0)[0]
        if len(ids) == 0:
            return  # No segmentation

        zrange = (ids[0], ids[-1] + 1)

        # NOTE(review): saving the same subject twice (S, then N) queues it
        # twice, so it is weighted double in training.
        self.previous.append((
            subject.img, subject.segm, zrange
        ))

        imgs = []
        segms = []

        for img, segm, (z_start, z_stop) in self.previous[-5:]:
            for z in range(z_start, z_stop):
                # z:z+1 keeps the leading axis, giving one-slice arrays.
                imgs.append(img[z:z + 1].astype(np.float32))
                # np.long does not exist in NumPy 1.24-1.26 (and is a C long
                # in 2.x). Use an explicit 64-bit integer label type --
                # presumably what the network expects (e.g. torch.long).
                segms.append((segm[z:z + 1] == 255).astype(np.int64))

        self.network.fit(imgs, segms)
        self.num_trained += 1

    def clear(self):
        """Clear the current subject view (``SubjectView.clear``) and repaint."""
        if not self.subject_view:
            return

        self.subject_view.clear()
        self.update()

    def paintEvent(self, event):
        """Draw the slice, its blue-tinted mask overlay and any in-progress path.

        The slice and the mask are both stretched over the whole widget. The
        mask is tinted blue with ``CompositionMode_SourceIn`` (only its opaque
        pixels take the colour) and drawn at 50% opacity.
        """
        if not self.subject_view:
            return

        painter = QPainter(self)

        pixmap = QPixmap(self.subject_view.slices[self.slice_index])
        painter.drawPixmap(self.rect(), pixmap)

        pixmap = QPixmap(self.subject_view.segms[self.slice_index])

        painter2 = QPainter(pixmap)
        painter2.setCompositionMode(QPainter.CompositionMode_SourceIn)
        painter2.fillRect(pixmap.rect(), QColor('blue'))
        painter2.end()

        painter.setOpacity(0.5)
        painter.drawPixmap(self.rect(), pixmap)
        painter.setOpacity(1)

        if self.path:
            # Green while adding, red while removing.
            color = QColor('white')
            if self.draw_mode == App.DRAW_ADD:
                color = QColor('green')
            elif self.draw_mode == App.DRAW_REMOVE:
                color = QColor('red')

            painter.setPen(QPen(color, 2))
            painter.drawPath(self.path)

    def wheelEvent(self, event):
        """Move one slice per 120 units of wheel rotation, clamped to range."""
        if not self.subject_view:
            return

        # Accumulate rotation and step once per full 120-unit notch.
        # Fine-resolution wheels and touchpads send smaller deltas; flooring
        # each one separately (delta // 120) would step backwards on every
        # tiny negative delta but never forwards.
        self._wheel_delta += int(event.angleDelta().y())
        steps = int(self._wheel_delta / App._WHEEL_STEP)  # truncates toward 0
        if steps == 0:
            return
        self._wheel_delta -= steps * App._WHEEL_STEP

        last = len(self.subject_view.slices) - 1
        self.slice_index = max(0, min(self.slice_index + steps, last))
        self.update()

    def mousePressEvent(self, event):
        """Start a freehand path; Shift selects remove mode, otherwise add."""
        if not self.subject_view:
            return

        mod = event.modifiers()
        if mod & Qt.ShiftModifier:
            self.draw_mode = App.DRAW_REMOVE
        else:
            self.draw_mode = App.DRAW_ADD

        self.path = QPainterPath()
        self.path.moveTo(event.pos())

    def mouseMoveEvent(self, event):
        """Extend the path being drawn to the cursor and repaint."""
        if self.path is None:
            return

        self.path.lineTo(event.pos())
        self.update()

    def mouseReleaseEvent(self, event):
        """Close the path and apply it to the displayed slice's mask.

        In add mode the enclosed region is filled white; in remove mode it is
        cleared with ``CompositionMode_Clear``. The widget-space path is first
        scaled to the mask's pixel size, because the slice is drawn stretched
        over the whole widget.
        """
        if self.path is None or not self.subject_view:
            return

        self.path.closeSubpath()

        segm = self.subject_view.segms[self.slice_index]

        painter = QPainter(segm)
        painter.setBrush(Qt.white)

        # Resample path to image space
        transform = QTransform()
        transform.scale(
            segm.width() / self.rect().width(),
            segm.height() / self.rect().height()
        )
        painter.setTransform(transform)

        if self.draw_mode == App.DRAW_REMOVE:
            painter.setCompositionMode(QPainter.CompositionMode_Clear)

        painter.fillPath(self.path, QColor('white'))
        painter.end()

        self.path = None
        self.draw_mode = App.DRAW_NONE
        self.update()

    def keyPressEvent(self, event):
        """Keyboard shortcuts: Esc quit, S save, N next subject, C clear."""
        key = event.key()
        if key == Qt.Key_Escape:
            # Nothing in this class assigns ``self.app``. Use it if a caller
            # attached one; otherwise close the window, which by default
            # (QApplication.quitOnLastWindowClosed) ends the application when
            # it is the last window.
            app = getattr(self, 'app', None)
            if app is not None:
                app.quit()
            else:
                self.close()
        elif key == Qt.Key_S:
            self.save()
        elif key == Qt.Key_N:
            self.next()
        elif key == Qt.Key_C:
            self.clear()
        else:
            # Let Qt handle keys this widget does not use.
            super().keyPressEvent(event)