from PyQt5.QtCore import Qt, QDir, QFile, QPoint, QRect, QSize
from PyQt5.QtWidgets import QWidget, QPushButton, QLineEdit, QLabel, QGridLayout
from PyQt5.QtGui import QImage, QImageWriter, QPainter, QPen, qRgb, QPixmap

import os
#import matplotlib.pyplot as plt

# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics
from PIL import Image
import numpy as np
import math
from sklearn.metrics import accuracy_score
from sklearn.externals import joblib

class Button(QWidget):
    """Widget wrapping a single push button placed in a grid layout.

    Args:
        parent: Optional parent widget.
        label: Text displayed on the button.
        on_click: Optional callable connected to the button's ``clicked``
            signal. When ``None`` (the default) no handler is connected.
    """

    def __init__(self, parent=None, label="Button", on_click=None):
        super(Button, self).__init__(parent)
        button = QPushButton(label, self)
        button.move(220, 0)
        # ``connect(None)`` raises TypeError, so the default "no handler"
        # case has to skip the connection instead of passing None through.
        if on_click is not None:
            button.clicked.connect(on_click)
        grid = QGridLayout()
        # Row 1, column 0, spanning one row and two columns.
        grid.addWidget(button, 1, 0, 1, 2)
        self.setLayout(grid)
        
class ComboWidget(QWidget):
    """Composite widget: a label, a centred line edit and a push button in one row.

    Args:
        parent: Optional parent widget.
        text_on_label: Text shown by the label on the left.
        text_on_button: Text shown on the button on the right.
        on_click: Optional callable connected to the button's ``clicked``
            signal. When ``None`` (the default) no handler is connected.
        readOnly: Whether the line edit is read-only.
    """

    def __init__(self, parent=None,
                 text_on_label="Label",
                 text_on_button="button",
                 on_click=None, readOnly=False):
        super(ComboWidget, self).__init__(parent)

        button = QPushButton(text_on_button, self)
        lbl = QLabel(self)
        lbl.setText(text_on_label)
        self.qle = QLineEdit(self)
        self.qle.setReadOnly(readOnly)
        self.qle.setAlignment(Qt.AlignCenter)
        # ``connect(None)`` raises TypeError, so the default "no handler"
        # case has to skip the connection instead of passing None through.
        if on_click is not None:
            button.clicked.connect(on_click)
        grid = QGridLayout()
        grid.addWidget(lbl, 1, 0)
        grid.addWidget(self.qle, 1, 1)
        grid.addWidget(button, 1, 2)
        self.setLayout(grid)

    def text(self):
        """Return the current content of the line edit."""
        return self.qle.text()

    def set_text(self, text=""):
        """Replace the content of the line edit with ``text``."""
        self.qle.setText(text)
    
class Board(QWidget):
    """Fixed-size 512x512 drawing canvas with helpers to export and load samples.

    The user draws with the left mouse button onto an off-screen ``QImage``
    (``self.image``) which is blitted to the widget in ``paintEvent``.
    Drawings can be saved as 64x64 PNG files, either under
    ``Image/<label>/`` (training samples) or as ``Test/temp.png`` (the
    sample to classify), and read back as binarised feature vectors.
    """

    def __init__(self, parent=None):
        super(Board, self).__init__(parent)
        self.setAttribute(Qt.WA_StaticContents)
        self.penWidth = 50          # stroke width in pixels
        self.writing = False        # True while a stroke is in progress
        self.penColor = Qt.black
        self.image = QImage()       # backing store; sized by resizeImage()
        self.lastPoint = QPoint()   # end point of the previous stroke segment
        self.setFixedSize(512, 512)

    def _scaledSnapshot(self):
        """Return the current drawing scaled to a 64x64 ``QImage``.

        The backing image is first brought to the widget size.  The image is
        read from ``self.image`` *after* that step because ``resizeImage``
        may replace ``self.image`` with a new object.
        """
        self.resizeImage(self.image, self.size())
        return self.image.scaled(64, 64)

    def saveImage(self, label):
        """Save the drawing as the next numbered PNG in ``Image/<label>/``.

        The file number is one more than the larger of (a) the number of
        ``.png`` files already present and (b) the highest numeric file name
        found, so existing samples are not overwritten.

        Args:
            label: Name of the class sub-directory to store the sample in.
        """
        snapshot = self._scaledSnapshot()

        label_dir = 'Image/' + label
        os.makedirs(label_dir, exist_ok=True)

        cnt = 0
        maxf = 0
        for name in os.listdir(label_dir):
            if not name.endswith('.png'):
                continue
            cnt += 1
            try:
                maxf = max(maxf, int(name[:-4]))
            except ValueError:
                # A PNG whose stem is not a number still counts as a file
                # but cannot contribute to the highest index.
                continue
        snapshot.save(label_dir + '/' + str(max(cnt, maxf) + 1) + '.png')

    def saveImageTest(self):
        """Save the drawing as ``Test/temp.png`` (64x64), creating ``Test`` if needed."""
        snapshot = self._scaledSnapshot()
        os.makedirs('Test', exist_ok=True)
        snapshot.save('Test/temp.png')

    def paintEvent(self, event):
        """Repaint the invalidated rectangle from the backing image."""
        painter = QPainter(self)
        dirtyRect = event.rect()
        painter.drawImage(dirtyRect, self.image, dirtyRect)

    def clearImage(self):
        """Fill the backing image with white and schedule a repaint."""
        self.image.fill(qRgb(255, 255, 255))
        self.update()

    def mousePressEvent(self, event):
        """Start a stroke at the cursor position on left-button press."""
        if event.button() == Qt.LeftButton:
            self.lastPoint = event.pos()
            self.writing = True

    def mouseMoveEvent(self, event):
        """Extend the current stroke while the left button is held down."""
        if (event.buttons() & Qt.LeftButton) and self.writing:
            self.drawLineTo(event.pos())

    def mouseReleaseEvent(self, event):
        """Draw the final segment and finish the stroke on left-button release."""
        if event.button() == Qt.LeftButton and self.writing:
            self.drawLineTo(event.pos())
            self.writing = False

    def resizeEvent(self, event):
        """Keep the backing image the same size as the widget."""
        self.resizeImage(self.image, self.size())
        self.update()
        super(Board, self).resizeEvent(event)

    def drawLineTo(self, endPoint):
        """Draw a segment from ``self.lastPoint`` to ``endPoint`` on the image.

        Only the rectangle touched by the segment (grown by the pen radius)
        is scheduled for repaint.

        Args:
            endPoint: ``QPoint`` where the segment ends; becomes the new
                ``self.lastPoint``.
        """
        painter = QPainter(self.image)
        painter.setPen(QPen(self.penColor, self.penWidth, Qt.SolidLine,
                            Qt.RoundCap, Qt.RoundJoin))
        painter.drawLine(self.lastPoint, endPoint)
        painter.end()
        self.writing = True

        # Integer division: QRect.adjusted() takes ints, and ``/`` would
        # produce a float in Python 3.
        rad = self.penWidth // 2 + 2
        self.update(QRect(self.lastPoint, endPoint).normalized()
                    .adjusted(-rad, -rad, +rad, +rad))
        self.lastPoint = QPoint(endPoint)

    def resizeImage(self, image, newSize):
        """Replace ``self.image`` with a white ``newSize`` image containing ``image``.

        Does nothing when ``image`` already has the requested size.

        Args:
            image: Source ``QImage`` drawn at the top-left corner of the new one.
            newSize: Target ``QSize``.
        """
        if image.size() == newSize:
            return
        newImage = QImage(newSize, QImage.Format_RGB32)
        newImage.fill(qRgb(255, 255, 255))
        painter = QPainter(newImage)
        painter.drawImage(QPoint(0, 0), image)
        # Release the paint device before the image is handed out.
        painter.end()
        self.image = newImage

    @staticmethod
    def _binarize_image(path):
        """Load the image at ``path`` as a flat 0/1 float vector.

        Each pixel becomes ``floor((R + G + B) / 765)`` where 765 is
        255 + 255 + 255, i.e. 1.0 for a pure white pixel and 0.0 otherwise.
        Pixels are emitted row by row.  The image is converted to RGB first
        so that grayscale, palette and RGBA files are handled as well.
        """
        with Image.open(path) as im:
            pixels = np.array(im.convert('RGB'))
        # Sum in a wide integer type so the three uint8 channels cannot wrap.
        channel_sum = pixels.sum(axis=-1, dtype=np.int64)
        return np.floor(channel_sum / 765).ravel()

    def data_prep_test(self):
        """Save the current drawing and return it as a one-sample feature list.

        Returns:
            A list holding a single 1-D array with one 0/1 value per pixel
            of ``Test/temp.png``.
        """
        self.saveImageTest()
        return [self._binarize_image('Test/temp.png')]

    def data_prep_train(self):
        """Load every sample under ``Image/<label>/`` as training data.

        Returns:
            A tuple ``(digit_data, digit_label)`` of NumPy arrays: the
            binarised feature vectors and the directory name each sample
            was found in.
        """
        print('preporcessing..')
        folder_name = 'Image/'
        digit_data = list()
        digit_label = list()
        for label in os.listdir(folder_name):
            label_path = os.path.join(folder_name, str(label))
            if not os.path.isdir(label_path):
                # Stray files next to the label directories are not classes.
                continue
            for data in os.listdir(label_path):
                # Normalise by dividing (R+G+B) by 765, i.e. (255+255+255).
                digit_data.append(
                    self._binarize_image(os.path.join(label_path, data)))
                digit_label.append(label)
        digit_data = np.array(digit_data)    # the samples
        digit_label = np.array(digit_label)  # the label of each sample
        return digit_data, digit_label

    def evaluation(self, digit_data, digit_label, classifier):
        """Print a classification report, confusion matrix and accuracy.

        Args:
            digit_data: Feature vectors passed to ``classifier.predict``.
            digit_label: Expected labels for ``digit_data``.
            classifier: Fitted estimator exposing ``predict``.
        """
        predicted = classifier.predict(digit_data)
        print("Classification report for classifier %s:\n%s\n"
              % (classifier, metrics.classification_report(digit_label, predicted)))
        print("Confusion matrix:\n%s" % metrics.confusion_matrix(digit_label, predicted))
        print("\naccuracy = ", accuracy_score(digit_label, predicted))