#####################################################################
# Exemple d'utilisation de SVM
# Utilisation des données des iris de Fisher
#####################################################################


#####################################################################
# import des librairies
#####################################################################
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import numpy as np
import sys

#
# Définition de chaines pour l'affichage
#
line = "=" * 50
separator = "\n" * 2


#####################################################################
# Données
#####################################################################

#
# On charge les données depuis sklearn
# On obtient un sklearn.utils.Bunch
#
iris = load_iris()

#
# Affiche les différents champs
#
iris_keys = iris.keys()
for k in iris_keys:
	print(line)
	print(k)
	print(iris[k])


#
# On transforme les données en DataFrame de pandas
# afin de pouvoir les manipuler plus simplement
#
df = pd.DataFrame(data = np.c_[ iris['data'], iris['target'] ], 
	columns= iris['feature_names'] + ['target'])

#
# On renomme les colonnes pour les manipuler plus
# simplement
# (inPlace=True évite de faire une copie du DataFrame)
#
df.rename( columns = {
	'sepal length (cm)': 'sepal_length' ,
	'sepal width (cm)': 'sepal_width', 
	'petal length (cm)': 'petal_length', 
	'petal width (cm)': 'petal_width'},
	inplace=True )

#
# On garde toutes les propriétés sauf la cible pour effectuer
# le calcul de la SVM
#
variables_du_modele = df.columns.drop(['target'])

#
# y est le vecteur de sortie à prédire
#
y = df.target

#
# X est une matrice des données en entrées
#
X = df[ variables_du_modele ]

#####################################################################
# Séparation des données
# - données d'apprentissage X_a, y_a
# - données de validation X_v, y_v
#####################################################################

#
# On crée deux jeux de données à partir de (X,y)
# - le jeu d'apprentissage X_a, y_a (70% des individus)
# - le jeu de test (X_v, y_v) qui permettra de vérifier 
#       la prédiction (30% des individus)
#
X_a, X_v, y_a, y_v = train_test_split(X, y, test_size=0.30)

#####################################################################
# Classifieur
#####################################################################

#
# création du classifieur
#
classifier = SVC(C=0.05, kernel='linear', gamma='auto')

#
# calcul sur ensemble d'apprentissage
#
print(line)
print(" Classification sur ensemble d'apprentissage")
print(line)
 
classifier.fit(X_a, y_a)

#####################################################################
# Prédiction
# sur le jeu de validation
#####################################################################

print(line)
print("Prédiction")
print(line)
y_v_prime = classifier.predict(X_v)
print("précision=", accuracy_score(y_v, y_v_prime) )
print( str(classifier) )

#####################################################################
# graphique des données
#####################################################################

#
# On travaille à présent sur toutes les données et on compare
# la prédiction 
#
y_prime =  classifier.predict(X)
fig = plt.figure()
axes = fig.add_subplot(111, projection='3d')
plt.title("Résultats Classification SVM")
axe1 = X['sepal_length']
axe2 = X['sepal_width']
axes.scatter(axe1, axe2, y.tolist(), c=y_prime, cmap=plt.cm.jet)

plt.show()


