Skip to content
Snippets Groups Projects
Select Git revision
  • 9537fc0f9bf38c70aead00f4e6e9f5c355ad6bc8
  • master default protected
2 results

AccountDatabase.java

Blame
  • decisiontree-student.py 1.50 KiB
    # Import des bibliothèques nécessaires
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.metrics import accuracy_score
    from sklearn.tree import plot_tree
    import matplotlib.pyplot as plt
    
    # Chargement des données depuis un fichier CSV
    file_path = './Data/student-data-train.csv'
    df = pd.read_csv(file_path)
    
    # Séparation des features et de la cible
    X = df.drop('success', axis=1)
    y = df['success']
    
    # Division des données en ensembles d'entraînement et de test
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  # Vous pouvez ajuster la taille du test si nécessaire
    
    # Construction de l'arbre de décision avec des paramètres spécifiques
    # Vous pouvez jouer avec les valeurs de min_samples_leaf et max_depth
    clf = DecisionTreeClassifier(min_samples_leaf=5, max_depth=3)
    clf.fit(X_train, y_train)
    
    # Prédictions sur les ensembles d'entraînement et de test
    y_train_pred = clf.predict(X_train)
    y_test_pred = clf.predict(X_test)
    
    # Mesure du taux de classification correcte
    train_accuracy = accuracy_score(y_train, y_train_pred)
    test_accuracy = accuracy_score(y_test, y_test_pred)
    
    print(f'Taux de classification correcte (Entraînement): {train_accuracy:.2f}')
    print(f'Taux de classification correcte (Test): {test_accuracy:.2f}')
    
    # Visualisation de l'arbre de décision
    plt.figure(figsize=(12, 8))
    plot_tree(clf, filled=True, feature_names=X.columns, class_names=['0', '1'], rounded=True)
    plt.show()