Skip to content
Snippets Groups Projects
Select Git revision
  • ac0d9ba5844a4bee16ad39919be6b7299d738381
  • main default protected
2 results

perceptron.py

Blame
  • perceptron.py 3.50 KiB
    # Author : Capt Thibault , Souza Luz Juliano
    # Date : 31.10.2023
    # Project : Perceptron
    # Description : Ce fichier représente notre travail pour le tp du perceptron
    import numpy as np
    import pandas as pd
    from matplotlib import pyplot as plt
    
    
    def upd_weights(wi_old, learning, t, y, xi):
        """
        Mettre à jour les poids
        :param wi_old: Les anciens poids
        :param learning: Le taux d'apprentissage ( qui contrôle la vitesse de convergence)
        :param t: La valeur cible
        :param y: La sortie actuelle du modèle pour l'exemple donné
        :param xi: Les caractéristiques de l'exemple d'apprentissage
        :return: Les poids mis à jour
        """
        return wi_old + learning * (t - y) * y * (1 - y) * xi
    
    
    def sigmoide(x):
        """
        Calcule la fonction sigmoïde (fonction d'activation) pour une valeur x
        :param x: La valeur d'entrée.
        :return: La sigmoide en fonction de x.
        """
        return 1 / (1 + np.exp(-x))
    
    
    if __name__ == '__main__':
        dataset = pd.read_csv("Data/student-data-train.csv", header=0)
        # Normalisation des colonnes
        dataset['norm_grade_1'] = (dataset['grade_1'] - dataset['grade_1'].mean()) / dataset['grade_1'].std()
        dataset['norm_grade_2'] = (dataset['grade_2'] - dataset['grade_2'].mean()) / dataset['grade_2'].std()
    
        # Extraction des données
        X = dataset[['norm_grade_1', 'norm_grade_2']].values
        y = dataset.iloc[:, 0].values
        num_features = X.shape[1]
        # ------------------- Paramètres ------------------
        learning_rate = 1e-2  # Taux d'apprentissage
        max_iterations = 2000
        # Initialisation aléatoire des poids (+1 pour le biais)
        poids = np.random.rand(num_features + 1)
        print("Poids initiaux:", poids)
        # Boucle d'apprentissage
        for iteration in range(max_iterations):
            total_error = 0
            for i in range(len(X)):
                # Ajout du biais (X0 = 1)
                input_data = np.insert(X[i], 0, 1)
                cible = y[i]
                # Calcul de la sortie du réseau
                sortie = sigmoide(np.dot(poids, input_data))
                # Mise à jour des poids
                poids = upd_weights(poids, learning_rate, cible, sortie, input_data)
                # Calcul de l'erreur quadratique
                total_error += (cible - sortie) ** 2
    
            # Affichage de l'erreur à chaque itération (à enlever pour de meilleures performances)
            print(f"Iteration {iteration + 1}: Erreur = {total_error}")
    
        # Calcul du taux de classification correcte
        correct_classifications = 0
        for i in range(len(X)):
            input_data = np.insert(X[i], 0, 1)
            cible = y[i]
            sortie = sigmoide(np.dot(poids, input_data))
            pred = 1 if sortie >= 0.5 else 0
            if pred == cible:
                correct_classifications += 1
    
        accuracy = correct_classifications / len(X)
        print(f"Taux de classifications correctes: {accuracy * 100}%")
    
        # Affichage de la droite de séparation des classes
        w1, w2, b = poids[1], poids[2], poids[0]
        pente = -w1 / w2
        intercept = -b / w2
        print(f"Droite de séparation: y = {pente}x + {intercept}")
    
        # Tracer la droite de séparation (diagonale)
        plt.figure()
        plt.scatter(X[y == 0][:, 0], X[y == 0][:, 1], color='red', label='Classe 0')
        plt.scatter(X[y == 1][:, 0], X[y == 1][:, 1], color='blue', label='Classe 1')
        plt.plot([-2, 2], [-2 * pente + intercept, 2 * pente + intercept], color='green', label='Droite de séparation')
        plt.title('Données avec la droite de séparation')
        plt.xlabel('Norm_Grade_1')
        plt.ylabel('Norm_Grade_2')
        plt.legend()
        plt.show()