Alberto Beiz

Mis pruebas y experimentos.

Contacta sin miedo:

Hola mundo con Fast.ai (PyTorch)

En Fast.ai han liberado un nuevo curso de Deep Learning para programadores, asi que al lío. Lo primero es realizar un clasificador de imágenes para no perder las buenas costumbres.

Objetivos

Antes de empezar

Al turrón

%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

La estructura es la misma que para el experimento con Keras. Carpetas train, valid, test y dentro de cada una carpetas dogs y cats. La librería fastai presupone esta estructura de carpetas cuando usamos su cargador de imágenes.

# Carpeta con los datos
PATH = "data/dogscats/"
# Tamaño de las imágenes
sz=224

Visualizando los datos

# Comprobamos el nombrado de los archivos
files = !ls {PATH}valid/cats | head 
files
['cat.10016.jpg',
 'cat.1001.jpg',
 'cat.10026.jpg',
 'cat.10048.jpg',
 'cat.10050.jpg',
 'cat.10064.jpg',
 'cat.10071.jpg',
 'cat.10091.jpg',
 'cat.10103.jpg',
 'cat.10104.jpg']
img = plt.imread(f'{PATH}valid/cats/{files[8]}')
plt.imshow(img);

png

img = plt.imread(f'{PATH}valid/dogs/dog.1001.jpg')
plt.imshow(img);

png

# Alto por ango y 3 canales de color
img.shape
(500, 347, 3)
img[:4,:4]
array([[[214, 209, 203],
        [202, 197, 191],
        [190, 185, 179],
        [186, 181, 175]],

       [[203, 198, 192],
        [190, 185, 179],
        [177, 172, 166],
        [171, 166, 160]],

       [[192, 187, 181],
        [177, 172, 166],
        [161, 156, 150],
        [153, 148, 142]],

       [[187, 182, 176],
        [170, 165, 159],
        [152, 147, 141],
        [142, 137, 131]]], dtype=uint8)

Elección del learning rate

# Cargamos un modelo preentrenado, en nuestro caso resnet34
arch=resnet34
data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz))
learn = ConvLearner.pretrained(arch, data, precompute=True)
# Buscamos el mejor learning rate
lrf = learn.lr_find()
 78%|███████▊  | 281/360 [00:06<00:01, 46.04it/s, loss=0.401] 
# Buscamos el mayor valor en el que la función de 
# pérdida todavía decrece de forma estable
learn.sched.plot()

png

En este caso tomaremos 0.01 porque 0.001 está demasiado cerca de la inestabilidad.

Entrenando la red neuronal

# Entrenamos el modelo
learn.fit(0.01, 3, cycle_len=1, cycle_mult=2)
epoch:   0, train_loss: 0.058019, val_loss: 0.023415, accuracy: 0.991699
epoch:   1, train_loss: 0.056788, val_loss: 0.025223, accuracy: 0.990723
epoch:   2, train_loss: 0.043230, val_loss: 0.023540, accuracy: 0.990723
epoch:   3, train_loss: 0.038578, val_loss: 0.022670, accuracy: 0.991211
epoch:   4, train_loss: 0.032743, val_loss: 0.021702, accuracy: 0.992188
epoch:   5, train_loss: 0.033366, val_loss: 0.031664, accuracy: 0.990234
epoch:   6, train_loss: 0.036556, val_loss: 0.021446, accuracy: 0.993652
# Calculamos las predicciones del set de validación
log_preds,y = learn.TTA()
probs = np.mean(np.exp(log_preds),0)
accuracy(probs, y)
0.99350000000000005

Un 99.35 % de acierto en menos de un minuto de entrenamiento, bruuutal.

Matriz de confusión

preds = np.argmax(probs, axis=1)
probs = probs[:,1]
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y, preds)
plot_confusion_matrix(cm, data.classes)

png

Visualizando errores

errors, errors_labels = zip(*[(i, preds[i]) for i in range(len(y)) if y[i] != preds[i]])
filenames = data.val_dl.dataset.fnames
errors_names = [filenames[i] for i in errors]
errors_names
['valid/cats/cat.12272.jpg',
 'valid/cats/cat.5583.jpg',
 'valid/cats/cat.2267.jpg',
 'valid/cats/cat.11735.jpg',
 'valid/cats/cat.10712.jpg',
 'valid/cats/cat.2893.jpg',
 'valid/cats/cat.11297.jpg',
 'valid/cats/cat.724.jpg',
 'valid/cats/cat.10107.jpg',
 'valid/dogs/dog.5336.jpg',
 'valid/dogs/dog.11186.jpg',
 'valid/dogs/dog.5231.jpg',
 'valid/dogs/dog.10103.jpg']
data.classes
['cats', 'dogs']
# Función para pintar una lista de imágenes
from matplotlib.gridspec import GridSpec
def plot_images(images, titles=None, cols=4):
    # Calculamos las filas necesarias
    rows = math.ceil(len(images)/cols)
    
    # Creamos la rejilla
    gs = GridSpec(rows, cols)
    f = plt.figure(figsize=(cols*2.5, rows*3))
    
    # Pintamos cada imagen
    for i in range(len(images)):
        # Añadimos la imagen a la rejilla
        s = f.add_subplot(gs[i//cols, i%cols])
        s.axis('Off')
        
        # Título si lo tiene
        if titles is not None:
            s.set_title(titles[i], fontsize=12)
            
        img = plt.imread(f'{PATH}{images[i]}')
        plt.imshow(img)
        
titles = [data.classes[label] for label in errors_labels]
plot_images(errors_names, titles)

png

En la mayoría de los errores las imágenes están poco definidas o aparece un perrete y un gatete.

Recapitulando

Fuentes