In this post, we will walk through the process of creating a deep segmentation model with SMP (Segmentation Models PyTorch), a PyTorch library of segmentation architectures. As an initial experiment to verify that the modeling pipeline works and the parameters are set up properly, we will also train the model on a small subset of the cloud satellite imagery from the Kaggle competition "Understanding Clouds from Satellite Images".

Highlights:

  • A quick deployment of a Fully Convolutional Neural Network (FCNN) using SMP.
  • Using real satellite imagery on cloud formation as training data.
  • A verification of the model set-up by training on a small subset of data.

First, let's import the necessary libraries and set up the environment.

In [1]:
import tensorflow as tf
tf.test.gpu_device_name()
Out[1]:
''
In [2]:
from google.colab import drive

drive.mount('/content/drive', force_remount=True)
Mounted at /content/drive
In [3]:
PATH = '/content/drive/My Drive/kaggle_cloud/data'
%cd /content/drive/My\ Drive/kaggle_cloud/src_cloudflower2/CloudFlower2

# Pull the latest version of the code (comment out to keep the current checkout):
!git pull
/content/drive/My Drive/kaggle_cloud/src_cloudflower2/CloudFlower2
Already up to date.
In [4]:
# The required libraries and their versions are pinned in requirements.txt
!pip install -r requirements.txt
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
Successfully installed GitPython-3.1.7 albumentations-0.3.2 catalyst-20.5.1 crc32c-2.0.1 deprecation-2.1.0 efficientnet-pytorch-0.6.3 gitdb-4.0.5 imgaug-0.2.6 munch-2.5.0 numpy-1.18.4 opencv-python-headless-4.3.0.36 pandas-1.0.3 pretrainedmodels-0.7.4 segmentation-models-pytorch-0.1.0 smmap-3.0.4 tensorboardX-2.1
In [5]:
import cv2
import matplotlib.pyplot as plt
import albumentations as albu
from albumentations import torch as AT

import numpy as np
import pandas as pd

import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch import nn
In [6]:
################################################################
# Seed the random number generators (so the results can be reproduced)
################################################################
torch.manual_seed(0)
np.random.seed(0)
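The two calls above seed PyTorch and NumPy, but they do not cover Python's built-in random module or cuDNN's nondeterministic kernel selection, so GPU runs can still differ slightly. Below is a minimal sketch of a broader seeding helper; the name seed_everything is our own and not part of the project's utils, and even with these flags some CUDA operations remain nondeterministic.

```python
import random

import numpy as np


def seed_everything(seed=0):
    # Hypothetical helper: seed every RNG the pipeline is likely to touch
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:  # lets the sketch run where PyTorch is not installed
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every visible GPU (no-op without CUDA)
    # Ask cuDNN for deterministic kernels and disable autotuning
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Calling seed_everything(0) at the top of the notebook would replace the two seeding lines above.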

1. Load Data

In this part, we load the cloud image dataset and take a quick look at it. To check that the network architecture is defined correctly and that the optimization algorithm works, we use only fifty images in total for training and validation.

In [7]:
from utils.dataset_helper import read_train_df, split_image_dataset

PATH = '/content/drive/My Drive/kaggle_cloud/data'

csvfile = f'{PATH}/train.csv'
data_df = read_train_df(csvfile)
print(data_df.shape)
data_df.head() 
(22184, 4)
Out[7]:
Image_Label EncodedPixels image_name label
0 0011165.jpg_Fish 264918 937 266318 937 267718 937 269118 937 27... 0011165.jpg Fish
1 0011165.jpg_Flower 1355565 1002 1356965 1002 1358365 1002 1359765... 0011165.jpg Flower
2 0011165.jpg_Gravel NaN 0011165.jpg Gravel
3 0011165.jpg_Sugar NaN 0011165.jpg Sugar
4 002be4f.jpg_Fish 233813 878 235213 878 236613 878 238010 881 23... 002be4f.jpg Fish
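The EncodedPixels column stores each mask as a run-length encoding: space-separated pairs of (start pixel, run length), with NaN when a label does not occur in the image. A minimal NumPy decoder is sketched below. It assumes the Kaggle convention of 1-indexed pixels numbered top to bottom, then left to right, and an original image size of 1400 x 2100 (height x width); the consecutive starts 264918, 266318, ... in the first row differ by 1400, which is consistent with column-major order on 1400-pixel-tall images. The name rle_decode is hypothetical; the project's own helper may differ.

```python
import numpy as np


def rle_decode(rle, shape=(1400, 2100)):
    """Decode a run-length string of (start, length) pairs into a binary mask.

    Assumes 1-indexed starts and column-major pixel order (top to bottom,
    then left to right). A missing label (NaN) yields an all-zero mask.
    """
    height, width = shape
    flat = np.zeros(height * width, dtype=np.uint8)
    if isinstance(rle, str) and rle.strip():
        values = np.array([int(v) for v in rle.split()])
        starts = values[0::2] - 1  # convert to 0-indexed positions
        lengths = values[1::2]
        for start, length in zip(starts, lengths):
            flat[start:start + length] = 1
    # order='F' fills the array column by column (column-major)
    return flat.reshape(shape, order='F')
```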
In [8]:
FOLDER = 'train_images'
sample_file = 'fStafea4f4.jpg'

CLOUD_LABELS = ['Fish', 'Flower', 'Gravel', 'Sugar']
In [9]:
torch.cuda.empty_cache()
In [10]:
################################################################
# Set up the data size in the experiment using max_n_images
################################################################
# Since we are only checking that the model is implemented correctly,
# use a small subset of the data.
df_train, df_valid = split_image_dataset(data_df, train_ratio=0.75, max_n_images=50)
print("#rows for the available dataset:", data_df.shape)
print("#rows for the training data: ", df_train.shape)
print("#rows for the validation data: ", df_valid.shape)
#rows for the available dataset: (22184, 4)
#rows for the training data:  (148, 4)
#rows for the validation data:  (52, 4)
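The row counts follow from splitting by image rather than by row: every image contributes four rows (one per label), so 50 images split 75/25 give int(0.75 x 50) = 37 training images (148 rows) and 13 validation images (52 rows). The function below is a hypothetical stand-in for utils.dataset_helper.split_image_dataset, whose source is not shown here; it only illustrates why splitting on image_name keeps all four masks of an image on the same side of the split.

```python
import numpy as np
import pandas as pd


def split_by_image(df, train_ratio=0.75, max_n_images=None, seed=0):
    # Hypothetical sketch: shuffle the unique image names, keep at most
    # max_n_images of them, and split the images (not the rows).
    rng = np.random.RandomState(seed)
    images = rng.permutation(df['image_name'].unique())
    if max_n_images is not None:
        images = images[:max_n_images]
    n_train = int(len(images) * train_ratio)
    train_images, valid_images = images[:n_train], images[n_train:]
    df_train = df[df['image_name'].isin(train_images)].reset_index(drop=True)
    df_valid = df[df['image_name'].isin(valid_images)].reset_index(drop=True)
    return df_train, df_valid
```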

2. Model Definition: UNET

In this part, we will set up a UNET model using Segmentation Models in Pytorch (smp). For detailed steps, please refer to: https://github.com/qubvel/segmentation_models.pytorch#start.

To feed the UNET, we use two data abstractions: (1) CloudDataset, which returns images and their masks; (2) PyTorch's DataLoader, which handles batching and shuffling. Together they provide the inputs for the FCNN model built in the next sections.

A tutorial on writing a DataLoader in PyTorch is available at: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html?highlight=dataloader.

a. The dataset class and a lean preprocessing step

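For simplicity, we rescale the pixel values of each image into the range [0, 1] with min-max scaling: subtract the image's minimum value and divide by its range (maximum minus minimum). For an image that spans the full 0-255 range, this is equivalent to dividing by 255.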

NOTE: Image augmentation is skipped for this experiment and the augmentation function only involves a resizing. No further normalization is applied.
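simple_preprocessing, defined in the next cell, uses per-image min-max scaling rather than a fixed division by 255. The two agree only when an image spans the full 0-255 range (as the sample inspected later does); otherwise min-max scaling stretches the contrast to fill [0, 1]. A small NumPy comparison on a made-up 2 x 2 image:

```python
import numpy as np

image = np.array([[10, 100], [150, 200]], dtype=np.float32)

# Fixed scaling: the output range depends on the image content
divided = image / 255.0
# Min-max scaling: the output always spans exactly [0, 1]
min_max = (image - image.min()) / (image.max() - image.min())

print(divided.min(), divided.max())
print(min_max.min(), min_max.max())
```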

In [11]:
from utils.cloud_dataset import CloudDataset

def simple_preprocessing(image, mask):
  # Min-max rescale the image into the range [0, 1], then move the channel
  # axis first (HWC -> CHW), the layout PyTorch models expect.
  gv_max, gv_min = np.max(image), np.min(image)
  if gv_max > gv_min:
    image = (image - gv_min) / (gv_max - gv_min)
  else:
    # Constant image: avoid a division by zero
    image = np.zeros(image.shape)

  image = image.transpose(2, 0, 1).astype('float32')
  mask = mask.transpose(2, 0, 1).astype('float32')

  preprocessed = {'image': image, 'mask': mask}

  return preprocessed


def test_simple_preprocessing():
  image = 100. * np.random.rand(4, 4, 3)
  mask = np.ones((4, 4, 1))
  rst = simple_preprocessing(image, mask)
  print(np.max(rst['image']))
  print(np.min(rst['image']))
  assert abs(np.max(rst['image']) - 1.0) < 1e-6
  assert abs(np.min(rst['image'])) < 1e-6

test_simple_preprocessing()
1.0
0.0
In [18]:
import segmentation_models_pytorch as smp

from utils.dataset_helper import get_training_augmentation
from utils.dataset_helper import get_validation_augmentation
from utils.dataset_helper import get_preprocessing

# ENCODER = 'resnet50'
# ENCODER = 'inceptionresnetv2'
ENCODER = 'resnet18'
ENCODER_WEIGHTS = 'imagenet'
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

image_folder = f'{PATH}/{FOLDER}'

# get_validation_augmentation() only applies a resize transformation:
train_dataset = CloudDataset(df_train, image_folder,
                             transforms=get_validation_augmentation(),
                             preprocessing=simple_preprocessing)

valid_dataset = CloudDataset(df_valid, image_folder,
                             transforms=get_validation_augmentation(),
                             preprocessing=simple_preprocessing)

#####################################
# Data:
#####################################

num_workers = 0  # number of subprocesses used to load the data
bs = 16  # batch size

train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True,
                          num_workers=num_workers)
# Shuffling is unnecessary for validation; a fixed order keeps the
# validation loss comparable across epochs
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False,
                          num_workers=num_workers)
loaders = {'train': train_loader, 
           'valid': valid_loader}
In [19]:
sample_ind = 0
# The original data sample: 
data_ori = train_dataset.get_data_by_index(sample_ind)
# Preprocessed data: 
data_pp = train_dataset[sample_ind]

assert(data_ori[0].ndim==3 and data_ori[0].shape[0]==3)
assert(data_pp[0].ndim==3 and data_pp[0].shape[0]==3)
In [20]:
from utils.dataset_helper import viz_image_mask_arrays
viz_image_mask_arrays(data_ori[0], data_ori[1])
In [21]:
viz_image_mask_arrays(data_pp[0], data_pp[1])

Since we applied only a resize operation, the only visible change is the image size (now 320 x 640). The plots also confirm that the image and its masks are resized consistently.
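Masks must stay binary after resizing, which is why masks are usually resized with nearest-neighbour interpolation; as far as we know, albumentations' Resize switches to nearest-neighbour for masks automatically. The hypothetical NumPy-only function below maps each output pixel to the input pixel under its centre, which shows why no fractional mask values can appear. The 1400 x 2100 input size in the usage check is an assumption about the original images.

```python
import numpy as np


def resize_nearest(mask, out_h, out_w):
    # Nearest-neighbour resize of an (H, W) or (H, W, C) array:
    # every output pixel copies the input pixel under its centre,
    # so only values already present in the mask can appear.
    in_h, in_w = mask.shape[:2]
    rows = ((np.arange(out_h) + 0.5) * in_h / out_h).astype(int)
    cols = ((np.arange(out_w) + 0.5) * in_w / out_w).astype(int)
    return mask[rows[:, None], cols[None, :]]
```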

In [ ]:
print("The size of the training dataset:", len(train_dataset))
print("The shape of input image data: ", data_pp[0].shape)
print("The shape of the input labels:", data_pp[1].shape)

print()
print("The max/min of the original image:", np.max(data_ori[0]), np.min(data_ori[0]))
print("The max/min of the preprocessed image:", np.max(data_pp[0]), np.min(data_pp[0]))
print("The mean/std of the preprocessed image:", np.mean(data_pp[0]), np.std(data_pp[0]))
The size of the training dataset: 37
The shape of input image data:  (3, 320, 640)
The shape of the input labels: (4, 320, 640)

The max/min of the original image: 255 0
The max/min of the preprocessed image: 1.0 0.0
The mean/std of the preprocessed image: 0.31877252 0.23482749
In [ ]:
#####################################
# UNET model using smp:
# Reference: https://github.com/qubvel/segmentation_models.pytorch#models
#####################################

# NOTE: The sigmoid activation is applied inside the loss function, so the
# model outputs raw logits and no activation is set here.
ACTIVATION = None
# ACTIVATION = 'sigmoid'
model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=4, 
    activation=ACTIVATION,
)
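Because ACTIVATION is None, the Unet returns raw logits of shape (batch, 4, height, width), one channel per cloud label, and the sigmoid is applied only inside the loss. At inference time the sigmoid and a threshold therefore have to be applied explicitly. A minimal NumPy sketch, assuming a 0.5 threshold; logits_to_masks is a hypothetical helper, and in practice the threshold could be tuned per label on validation data.

```python
import numpy as np


def logits_to_masks(logits, threshold=0.5):
    # Convert raw model outputs (logits) to binary masks:
    # sigmoid -> probability, then threshold per pixel.
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs > threshold).astype(np.uint8)
```

With threshold=0.5 this is the same as keeping pixels whose logit is positive.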

b. Loss Function

In [ ]:
from loss.dice_loss import BCEDiceLoss, DiceLoss, dice_no_threshold
criterion = BCEDiceLoss(activation='sigmoid')
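BCEDiceLoss comes from the project's loss.dice_loss module, whose source is not shown. As the name suggests, such a loss typically sums a pixel-wise binary cross-entropy term and a soft Dice term, 1 - 2|P∩G| / (|P| + |G|), computed on sigmoid probabilities. The NumPy sketch below is our own illustration of that combination, not the module's implementation; details such as per-class averaging or smoothing constants may differ.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bce_dice_loss(logits, targets, eps=1e-7):
    """Hypothetical BCE + soft Dice loss on raw logits (illustration only)."""
    probs = sigmoid(logits)
    # Clip only for the log terms to avoid log(0)
    probs_c = np.clip(probs, eps, 1 - eps)
    bce = -np.mean(targets * np.log(probs_c) + (1 - targets) * np.log(1 - probs_c))
    # Soft Dice uses probabilities directly, so it stays differentiable
    intersection = np.sum(probs * targets)
    dice = (2 * intersection + eps) / (np.sum(probs) + np.sum(targets) + eps)
    return bce + (1 - dice)
```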

3. Model Training

In [ ]:
#####################################
# Optimizer settings: 
#####################################
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR

optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), 'lr': 1e-2}, 
    {'params': model.encoder.parameters(), 'lr': 1e-3},  
])
# Other values of learning rates which have been tested:
# (1e-3, 1e-4), (5e-3, 5e-4), (5e-2, 1e-2)

# opt_level = 'O1'
# model.cuda()
# model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Use a scheduler to reduce the learning rates when the monitored loss plateaus:
scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
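The optimizer uses two parameter groups: the decoder, which starts from random initialization, gets a 10x larger learning rate than the ImageNet-pretrained encoder. ReduceLROnPlateau (default mode='min') multiplies every group's learning rate by factor=0.15 once the monitored loss has failed to improve for more than patience=2 consecutive epochs, so the 10:1 ratio between the groups is preserved. The pure-Python function below is a simplified model of that rule written for illustration; it ignores the scheduler's threshold, cooldown, and min_lr options.

```python
def simulate_plateau(losses, lr, factor=0.15, patience=2):
    # Simplified ReduceLROnPlateau(mode='min'): no threshold, cooldown or min_lr
    best = float('inf')
    num_bad = 0  # consecutive epochs without improvement
    lrs = []
    for loss in losses:
        if loss < best:
            best = loss
            num_bad = 0
        else:
            num_bad += 1
        if num_bad > patience:
            lr *= factor
            num_bad = 0
        lrs.append(lr)
    return lrs
```

With these settings, three reductions would bring the decoder rate from 1e-2 to about 3.4e-5.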
In [ ]:
from tqdm.auto import tqdm as tq
In [ ]:
# for i, p in enumerate(model.encoder.parameters()):
#   print(p.data.max(), p.data.min())

print("BEFORE training:")
params = list(model.encoder.parameters())
print("#parameter tensors:", len(params))
print("The max/min of the params[0]:", params[0].data.max(), params[0].data.min())
print("The mean/std of the params[0]:", params[0].data.mean(), params[0].data.std())
print("params[0].grad is None:", params[0].grad is None)
BEFORE training:
#parameter tensors: 60
The max/min of the params[0]: tensor(1.0165) tensor(-0.8434)
The mean/std of the params[0]: tensor(2.9420e-05) tensor(0.1297)
params[0].grad is None: True
In [ ]:
torch.cuda.empty_cache()
In [ ]:
from models.runner import Runner

from datetime import datetime

rst_path = "/content/drive/My Drive/kaggle_cloud/run_20200617"

cld_runner = Runner(model, criterion)
train_rst = cld_runner.train(train_loader, valid_loader, 
                optimizer, scheduler, 
                valid_score_fn = dice_no_threshold, 
                n_epochs = 50, train_on_gpu=True, verbose=True, rst_path = rst_path)
train_loss_list, valid_loss_list, dice_score_list = train_rst[0], train_rst[1], train_rst[2]
In [ ]:
params = list(model.encoder.parameters())
print("AFTER training:")
print("")
print("#parameter tensors:", len(params))
print("The max/min of the params[0]:", params[0].data.max(), params[0].data.min())
print("The mean/std of the params[0]:", params[0].data.mean(), params[0].data.std())
print("params[0].grad is None:", params[0].grad is None)
AFTER training:

#parameter tensors: 60
The max/min of the params[0]: tensor(1.0188, device='cuda:0') tensor(-0.8490, device='cuda:0')
The mean/std of the params[0]: tensor(-0.0002, device='cuda:0') tensor(0.1299, device='cuda:0')
params[0].grad is None: False
In [ ]:
import matplotlib.pyplot as plt

plt.figure(figsize=(8,5))
plt.plot(train_loss_list,  marker='o', label="Training Loss")
plt.ylabel('loss', fontsize=22)
plt.legend()
plt.show()
In [ ]:
plt.figure(figsize=(8, 5))

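# Epochs 0-4 are left out of this plot; keep the x-axis aligned with the epoch index
start = 5
plt.plot(range(start, len(valid_loss_list)), valid_loss_list[start:], marker='o', label="Validation Loss")
plt.xlabel('epoch', fontsize=22)
plt.ylabel('loss', fontsize=22)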
plt.legend()
plt.show()
In [ ]:
plt.figure(figsize=(8,5))
plt.plot(dice_score_list)
plt.ylabel('Dice score')
plt.show()

Summary:

In this simple experiment, we have achieved the following:

  • We have built a learning pipeline that preprocesses the image data, defines a segmentation model with a high-level library, and trains the model on a small dataset to verify the set-up.
  • With a small training dataset and suitable optimizer settings (learning rates), the training loss decreases effectively, although there is still room for improvement.