Alberto Beiz

Mis pruebas y experimentos.

Contacta sin miedo:

Visualizando datos y predicciones con Keras y Matplotlib

Vamos a ver funciones simples para ver imágenes y buscar resultados. De manera que podamos entender mejor en qué acierta y qué falla nuestro modelo.

Visualizando los datos de entrenamiento

import keras
import utils
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
(x_train, y_train), (x_test, y_test) = utils.load_f_mnist_as_lists()
# Visualizar una imagen
img = x_train[0]
pixels = img.reshape((28, 28))
plt.imshow(pixels, cmap='gray')
plt.show()

png

# Lista de clases de Fashion MNIST
f_mnist_categories = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
                      'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# Visualizar una imagen con su clase
img = x_train[0]
pixels = img.reshape((28, 28))

# Buscamos la posición del 1 en el hot encoding
max_cat = y_train[0].argmax()
# Y la sacamos de la lista de clases
plt.title(f_mnist_categories[max_cat])

plt.imshow(pixels, cmap='gray')
plt.show()

png

# Función para pintar una lista de imágenes
def plot_mnist(images, titles=None, cols=4):
    # Calculamos las filas necesarias
    rows = math.ceil(len(images)/cols)
    
    # Creamos la rejilla
    gs = GridSpec(rows, cols)
    f = plt.figure(figsize=(cols*2, rows*2.5))
    
    # Pintamos cada imagen
    for i in range(len(images)):
        # Añadimos la imagen a la rejilla
        s = f.add_subplot(gs[i//cols, i%cols])
        # Quitamos los ejes
        s.axis('Off')
        
        # Título si lo tiene
        if titles is not None:
            s.set_title(titles[i], fontsize=12)
            
        img = images[i]
        pixels = img.reshape((28, 28))
        plt.imshow(pixels, cmap='gray')
# Ahora podemos pintar series de imágenes
titles = [f_mnist_categories[classes.argmax()] for classes in y_train[0:8]]
plot_mnist(x_train[0:8], titles)

png

Visualizando las predicciones

# Usamos las predicciones del post anterior
y_predict = utils.load_array('predicciones.dat')
# Podemos ver la predicción y su probabilidad
titles = [f_mnist_categories[classes.argmax()]
          +' \n '+
          str(max(classes)) for classes in y_predict[0:8]]
plot_mnist(x_test[0:8], titles)

png

Buscando los fallos

# para ver los fallos, calculamos los índices de
# las prediciones que no coinciden con su clase
y_errors_ind = [ind for ind, classes in enumerate(y_predict) 
                    if classes.argmax() != y_test[ind].argmax()]
# Extraemos los elementos con error
y_errors = y_predict[y_errors_ind]
x_errors = x_test[y_errors_ind]
y_real = y_test[y_errors_ind]

# Y los pintamos
titles = ['Pred: ' + f_mnist_categories[classes.argmax()]
          +' \n '+ 
          'Real: ' + f_mnist_categories[y_real[ind].argmax()]
          for ind, classes in enumerate(y_errors[0:8])]
plot_mnist(x_errors[0:8], titles)

png

Matriz de confusión

Por último vamos a representar las predicciones en un gráfico que facilita visualizar los aciertos y errores de una pasada, una matriz de confusión.

# Necesitamos pasar nuestras predicciones a
# números de nuevo, revirtiendo el one hot encoding1
y_predict_classes = [classes.argmax() for classes in y_predict]
y_test_classes = [classes.argmax() for classes in y_test]
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test_classes, y_predict_classes)
matrix = utils.plot_confusion_matrix(cm, f_mnist_categories)

png

Podemos ver por ejemplo que nuestro modelo no distingue bien entre camisas y camisetas, zapatos y sandalias o entre pantalones y vestidos. Aunque en este caso los errores son bastante lógicos este tipo de gráficos puede descubrirnos errores inesperados.

Conclusión

Con estas simples herramientas podemos empezar a estudiar los resultados de nuestro modelo de una forma mucho más interesante y profunda que un simple numerito. De manera que podemos focalizar nuestros esfuerzos en los casos que lo requieren, en vez de pegar tiros al aire.

Fuentes