Alberto Beiz

Mis pruebas y experimentos.

Contacta sin miedo:

Clasificador de imágenes con múltiples clases

Como todo en la vida no son perros o gatos y a veces una imágen puede pertenecer a varias clases, vamos a aprender como enfrentarnos a este problema. Para ello vamos a usar el dataset de la competición de Kaggle Planet, que consiste en etiquetar fotos satelitales del amazonas según su contenido.

Objetivos

Antes de empezar

Al turrón

# Helpers
%reload_ext autoreload
%autoreload 2
%matplotlib inline
# Imports externos
from fastai.imports import *
# Importamos la libreria
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
PATH = "data/planet/"
# He descomprimido las imagenes
# en train-jpg
ls {PATH}
models/                       test-jpg-additional.tar.7z    train-jpg/
sample_submission_v2.csv.zip  test_v2_file_mapping.csv.zip  train_v2.csv
test-jpg/                     tmp/

Cada imágen puede contener distintos tipos de terreno y de tiempo atmosférico.

list_paths = [f"{PATH}train-jpg/train_0.jpg", f"{PATH}train-jpg/train_1.jpg"]
titles=["haze primary", "agriculture clear primary water"]
plots_from_files(list_paths, titles=titles, maintitle="Multi-label classification")

png

import warnings
from sklearn.metrics import fbeta_score

# Función de evaluación de la competición
def f2(preds, targs, start=0.17, end=0.24, step=0.01):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return max([fbeta_score(targs, (preds>th), 2, average='samples')
                    for th in np.arange(start,end,step)])
metrics=[f2]
f_model = resnet34
# Las clases de cada imagen nos las proporcionan
# en un csv
label_csv = f'{PATH}train_v2.csv'
n = len(list(open(label_csv)))-1

# Creamos el validation set sacando
# una parte de las imágenes
val_idxs = get_cv_idxs(n)
# Función para la carga de imágenes con data augmentation.
# aplicamos transforms_top_down para flipear
# las imágenes en horizontal y vertical
# y un poco de zoom
def get_data(sz):
    tfms = tfms_from_model(f_model, sz, aug_tfms=transforms_top_down, max_zoom=1.05)
    return ImageClassifierData.from_csv(PATH, 'train-jpg', label_csv, tfms=tfms,
                    suffix='.jpg', val_idxs=val_idxs, test_name='test-jpg')
data = get_data(256)
# Comprobamos que una imagen pertenece a
# varias categorías
x,y = next(iter(data.val_dl))
list(zip(data.classes, y[0]))
[('agriculture', 1.0),
 ('artisinal_mine', 0.0),
 ('bare_ground', 0.0),
 ('blooming', 0.0),
 ('blow_down', 0.0),
 ('clear', 1.0),
 ('cloudy', 0.0),
 ('conventional_mine', 0.0),
 ('cultivation', 0.0),
 ('habitation', 0.0),
 ('haze', 0.0),
 ('partly_cloudy', 0.0),
 ('primary', 1.0),
 ('road', 0.0),
 ('selective_logging', 0.0),
 ('slash_burn', 0.0),
 ('water', 1.0)]
plt.imshow(data.val_ds.denorm(to_np(x))[0]*1.4);

png

# Reducimos el tamaño para que las
# primeras pruebas sean más rápidas
sz=64
data = get_data(sz)
# Creamos la red neuronal convolucional
learn = ConvLearner.pretrained(f_model, data, metrics=metrics)
# Buscamos el learning rate adecuado
lrf=learn.lr_find()
learn.sched.plot()
epoch      trn_loss   val_loss   f2                          
    0      0.228332   0.388202   0.790217  

png

lr = 0.2

Ahora vamos a entrenar la red aumentando el tamaño de las imágenes poco a poco. Hacemos esto porque las imágenes en las que resnet ha sido pre-entrenada son muy distintas a las imágenes satelitales con las que vamos a trabajar.

Queremos que la red vaya poco a poco aprendiendo las características de nuestras imágenes a la vez que evitamos que llegue a producirse overfitting mediante el aumento de tamaño de las imágenes.

learn.fit(lr, 3, cycle_len=1, cycle_mult=2)
epoch      trn_loss   val_loss   f2                          
    0      0.14549    0.132691   0.883379  
    1      0.140758   0.126412   0.889713                    
    2      0.137548   0.124952   0.891458                    
    3      0.135616   0.124114   0.892314                    
    4      0.13127    0.121797   0.893941                    
    5      0.129936   0.120531   0.895511                    
    6      0.129624   0.120619   0.895055                    
lrs = np.array([lr/9,lr/3,lr])
learn.unfreeze()
learn.fit(lrs, 3, cycle_len=1, cycle_mult=2)
epoch      trn_loss   val_loss   f2                          
    0      0.113274   0.10409    0.914445  
    1      0.109907   0.100968   0.916191                    
    2      0.104249   0.097739   0.918716                    
    3      0.107126   0.098832   0.9183                      
    4      0.102634   0.096218   0.921197                    
    5      0.095631   0.093848   0.92325                      
    6      0.096058   0.093187   0.922863                     
learn.save(f'{sz}')
learn.load(f'{sz}')
sz = 128
learn.set_data(get_data(sz))
learn.freeze()
learn.fit(lr, 3, cycle_len=1, cycle_mult=2)
epoch      trn_loss   val_loss   f2                           
    0      0.097527   0.093736   0.921393  
    1      0.094811   0.093001   0.921954                     
    2      0.096117   0.092129   0.922715                     
    3      0.096227   0.092706   0.922573                     
    4      0.094526   0.091443   0.9222                       
    5      0.092762   0.091316   0.923078                     
    6      0.092021   0.090788   0.923549                     
learn.unfreeze()
learn.fit(lrs, 3, cycle_len=1, cycle_mult=2)
learn.save(f'{sz}')
epoch      trn_loss   val_loss   f2                           
    0      0.095568   0.087335   0.927648  
    1      0.092487   0.087872   0.92696                      
    2      0.089898   0.08599    0.928536                     
    3      0.091247   0.0887     0.927887                     
    4      0.08698    0.085667   0.928287                     
    5      0.085485   0.085197   0.929813                     
    6      0.082507   0.083924   0.93066                      
sz = 256
learn.set_data(get_data(sz))
learn.freeze()
learn.fit(lr, 3, cycle_len=1, cycle_mult=2)
epoch      trn_loss   val_loss   f2                           
    0      0.091713   0.089376   0.924707  
    1      0.087906   0.088589   0.925556                     
    2      0.088206   0.088269   0.925404                     
    3      0.088546   0.088203   0.927165                     
    4      0.089512   0.08705    0.927212                     
    5      0.086984   0.087283   0.926282                     
    6      0.090308   0.087335   0.926741                     
learn.unfreeze()
learn.fit(lrs, 3, cycle_len=1, cycle_mult=2)
learn.save(f'{sz}')
epoch      trn_loss   val_loss   f2                           
    0      0.084487   0.082897   0.931791  
    1      0.088547   0.084135   0.93144                      
    2      0.08152    0.082441   0.932349                     
    3      0.087541   0.084951   0.929821                     
    4      0.082985   0.082579   0.93194                      
    5      0.080027   0.081958   0.932279                     
    6      0.079555   0.081633   0.932907                     

Todavía podríamos entrenar y mejorar un poco más, no hemos llegado al overfitting. Pero no quiero gastar más tiempo, me doy por satisfecho.

multi_preds, y = learn.TTA()
preds = np.mean(multi_preds, 0)
# El ganador de la competición consiguió 93.4
# Con un poco de ajuste y entrenamiento conseguiríamos
# superar ese valor
f2(preds,y)
0.9309057293028824
# Podemos ver un ejemplo de predicción
plt.imshow(data.val_ds.denorm(to_np(x))[0]*1.4);

png

# Comparamos la predicción con el valor real
[i for i in zip(data.classes, y[0], preds[0])]
[('agriculture', 1.0, 0.70372528),
 ('artisinal_mine', 0.0, 0.00066884671),
 ('bare_ground', 0.0, 0.011150284),
 ('blooming', 0.0, 0.00070670969),
 ('blow_down', 0.0, 0.001093367),
 ('clear', 1.0, 0.99982405),
 ('cloudy', 0.0, 4.6757359e-05),
 ('conventional_mine', 0.0, 0.00028788339),
 ('cultivation', 0.0, 0.35036865),
 ('habitation', 0.0, 0.017872976),
 ('haze', 0.0, 0.00020892653),
 ('partly_cloudy', 0.0, 4.8972048e-05),
 ('primary', 1.0, 0.99973029),
 ('road', 0.0, 0.21921349),
 ('selective_logging', 0.0, 0.0047851792),
 ('slash_burn', 0.0, 0.01710622),
 ('water', 1.0, 0.78968823)]

Recapitulando

Fuentes