Недавно я решил вернуться к освоению нейронных сетей. Меня привлекает их мощь и красота математического аппарата под капотом. Мое первоначальное знакомство с нейронками случилось пару лет назад. Я тогда потыкал классификацию, GANы, автоэнкодеры. Сейчас же я захотел заняться задачей сегментации. Решил изучать теорию и сразу же пробовать её на практике.
Ссылка на код: https://github.com/ArtemBoyarintsev/cell_segmentation
Датасет
В качестве практической задачи я выбрал задачу сегментации клеток. Датасет называется Colon Cancer Human CT 29. Исходный датасет содержит 41 исходных снимков клеток и 41 ответ - выделенные границы клеток на исходных снимках. Вот пример:
Может показаться, что 41 пример - это крайне маленький датасет. Например, известно, что датасет для классификации без предобученной модели требует хотя бы несколько десятков тысяч изображений. А в Image net вообще размечено около 14 млн изображений. Задача сегментации требует в принципе на порядки меньше изображений для обучения. Происходит это по двум причинам:
- Сетки для сегментации имеют меньше параметров
- Выходом в сегментации является тоже картинка. Таким образом, каждый нейрон во внутренних слоях получает feedback от тысяч выходных нейронов. В задаче же классификации выход только один.
В купе с требованием малого объема тренировочных данных можно обратиться к инструменту аугментации и раздуть тренировочный датасет как минимум до 10 тысяч картинок.
Суть задачи сегментации
Вообще задача сегментации это по сути попиксельная классификация. И она находит применение во многих отраслях. Например, в задачах self driven cars (автопилоты) есть исключительная необходимость понимать, где находится дорога, полосы для движения, другие автомобили, знаки дорожного движения, люди и прочее.
Из-за отсутствия полносвязных слоев в этих сетях сегментационная сеть работает с картинками любого размера. Это позволяет тренировать сеть на картинках одного размера, а использовать на картинках любого другого размера. Почему это происходит здесь, но не происходит в классификационных сетях? Потому что в классификационных сетях перед переходом из сверточных слоев в полносвязные происходит flatten - это когда матрица абсолютного любого размера трансформируется в вектор, и необходимо четко знать длину этого вектора.
Реализация
Любая нейронная сеть начальными слоями существенно уменьшает размерность исходной картинки. В задаче сегментации же необходимо выход нейронной сети привести к размеру исходной картинки. Для этого в основном используется механизм up-convolution.
Upconvolution
Что такое upconvolution? Это хитрый механизм для повышения размерности feature-maps. Он работает почти как обычная операция свертки. Только обработка кусочка картинки обычной сверткой порождает один пиксель.
Upconvolution же наоборот - использует в качестве входа один пиксель исходной картинки. Перемножает этот пиксель с ядром свертки. И, таким образом, порождает картинку размерности ядра свертки.
Итоговый результат после операции upconvolution будет следующий:
Если вдруг это остается непонятным - можно думать о upconvolution как о более умном подходе апсемплинга.
Архитектуры
Поскольку я почти ничего в этой теме не знал, то решил реализовать несколько известных архитектур в задаче сегментации. На сегодняшний день опробованы следующие архитектуры:
- FCN8 на базе VGG16
- UNET
FCN8
Начал я с первой попавшейся мне на глаза архитектуры из класса Fully Convolutional networks. В FCN-8 в качестве основной архитектуры (backbone) используется VGG-16. От неё выкидываются линйеные слои и остаются только конволюционные слои, то есть те, которые по сути выделяют признаки на картинке. Добавляется несколько скип-коннэкшенов, апконволюций и собственно все.
Кстати, это была в принципе моя первая архитектура в жизни, которую я реализовывал со схематичной картинки. У меня возник вопрос, как не ошибиться в размерах. Оказалось очень просто:
- Взял vgg16 в torchvisions (torchvisions.models.vgg16(pretrained=True))
- Достал из неё feature_extraction часть (та, что перед flatten и dense слоями)
- Затем добавлял слой за слоем, проверяя то, что размер выходной картинки после реализованного слоя совпадает с ожидаемым размером (который я брал из схемы выше)
Фух! Реализовал. Каков же результат нетренированной сети?
Тут у меня была довольно большая надежда - видно, что места с клетками на original test в целом помечаются как-то более насыщено, а пустые места остаются почти пустыми. Но это надежда скоро умерла.
Итоговый результат получился далеко от того, что ожидалось... Так что я переключился на UNET. Может быть я где-то ошибку допустил, а может архитектура для этой задачи не подходящая...
UNET
Революционной идеей UNET, которая выделила его среди прочих архитектур, было то, что столь огромное количество скип-коннекшенов позволило сделать границы объектов более точными. Нижележащие слои классифицирует объекты на картинке, а скип-коннекшены из других слоев позволяют уточнять, где классифицированный объект находится. Меня UNET удивил тем, что он уже без какого-либо обучения показывал интересные результаты. Пример того, что выдавала модель без тренировки, каких-либо претренировочных весов и особой инициализации весов показан ниже.
Спустя 32 итерации результат получился следующий.
Когда я это увидел в начале была радость, а потом небольшое недоумение. С одной стороны, задача сегментации решена корректна. Помним, что задача сегментации с формальной точки зрения - это попиксельная классификация, и здесь мы видим, что клетки, хорошо выделяются нейронной сетью (3-я колонка). Следовательно, нейронная сеть хорошо справилась с задачей. Но с другой стороны - как она могла верно сегментировать клетки, если в обучающей выборке были только выделенные границы (2-ая колонка)? Единственный ответ, который пришел мне на ум был следующей. Во-первых, это результат работы скип-коннекшенов, который по сути пробрасывает менее обработанную картинку вперед, минуя часть слоев.
Во-вторых, любая сегментационная нейронная сеть по сути представляет энкодер-декодер архитектуру и имеет внутри себя ботлнек. Таким образом, группа пикселей на исходном изображении где-то во внутренних слоях нейронной сети представлена “одним пикселем” (в кавычках, потому что по сути внутри сети это не пиксель, а просто выходы нейронов). И следовательно, при декодировании весь этот один пиксель может раскодироваться только во что-то одно: в моем случае, либо в клетку, либо не в клетку (фон).
Что было дальше?
Спустя пару недель после результатов, которые были на картинке 8, я кинулся на соревнования по задаче сегментации на платформе Kaggle под названием HuBMAP. Я взял написанный код обучения для задачи сегментации клеток и применил его в задаче HuBMAP. Результат оказался мягко говоря печальным. Я оказался на одном из последних мест среди тех участников, кто просто хоть что-то засамбител. Поэтому следующей итерацией стало использование так называемого segmentation loss. Его идея состоит в комбинации функции потерь из бинарной кросс-энтропии и добавления части из оптимизируемой метрики. В случае с оптимизацией dice-метрики (популярная метрика для задачи сегментации) это выглядит так:
Запустив считаться ноутбук с решением для соревнования, мне стало интересно, насколько бы отличались итоговые результаты решения этой учебной задачи сегментации. Удовлетворением этого интереса я и решил заняться, пока для задачи соревнования делать нечего. Результат оказался просто ошеломительным.
Уже после первого этапа работы сети, сеть показывала просто потрясающие результаты. Конечный результат работы сети так же показал великолепные результаты.
Ура! Сетка начала выдать именно те результаты, которые я ей показывал при обучении. Вы спросите меня: “А как же все те прелестные объяснения про скип-коннекшены и внутреннее сжатое представление пикселей?”. Я отвечу: “Не знаю :)”. Какое-то время назад я понял, что это в целом довольно распространенное явление объяснить что-то постфактум красиво и правдоподобно, но потом какой-то практический пример не укладывается в рамки приведенного объяснения :)
Почему же результат так сильно отличается? Дело в функции потерь, друзья! Когда я первый раз решал эту задачу, то использовал вообще L2 loss по незнанию. Таким образом, на собственных граблях я узнал, что L2 loss не очень применим к задаче классификации (коей является задача сегментации) :)
Заключение
На этом пока закончились мои эксперименты с задачей сегментации. В близжайшее время я планирую применить некие Feature Pyramid Network. Они призваны решить проблему с масштабом объектов на картинке. К примеру, если у вас тренировачный датасет состоит из кошечек, занимающих ровно пол картинки; а во время применения (inference) вы подадите на вход кошечку, занимающую только 1/8 части картинки, то вполне вероятно, что сетка либо ошибется, либо не будет уверена, что на картинке именно кошечка.
Делитесь вашими мыслями в комментариях! Всем добра! :)