Решающие деревья

Предполагается изучение литературы 😮

Решающие деревья - строго математично

Популярно с примером кода

Задание на реализацию дерева решений из курса МГУ

Дополнительно, если интересно:

Основные понятия машинного обучения

Деревья принятия решений (Decision Tree)

Деревья бывают разные - черный, белые, красные. И все их одинаково хочется запрограммировать. Но возникает резонный вопрос - а где же их использовать и зачем они вообще нужны? Мы с вами уже бегло познакомились с обычным бинарным деревом поиска, когда писали своё множество (Set), чтобы хранить только уникальные объекты, причем эффективно. Дальнейшим развитием идеи мог бы стать словарь (Ductionary, Map). Для его эффективной реализации нам бы стоило познакомиться с типом данных 2-3-дерево или красно-черное дерево.

В общем случае, как вы уже поняли, деревья используют для поиска и для создания какого-то порядка (можно написать сортировку с помощью деревьев). И если развить эту идею дальше, то можно написать алгоритм принятия решений. В этом случае мы будем в дереве искать ответы на вопросы и в результате спустимся к листьям - ответам.

На этой практике мы будем реализовывать структуру данных для принятия решений, используя дерево поиска. На самом деле это уже структура данных из семейства алгоритмов машинного обучения. Таким образом наше дерево в процессе своего построения будет учиться принимать решения и в конце предсказывать некоторые результаты на основе новых данных. То есть будет искать ответы на вопросы, оно же дерево поиска!

Небольшая теория машинного обучения

Пусть у нас есть некоторый набор данных, который удобно представить в виде таблицы, где в строках - объекты.

Рост Вес Гендер
180 75 М
168 50 Ж
175 60 М
190 70 Ж

Рост, Вес, Гендер называют признаками объектов. Если мы хотим построить модель предсказания пола человека по его весу и росту, тогда мы будем говорить, что Рост, Вес - это известные признаки $X$, а Гендер - это целевой признак(переменная) $Y$.

Задача машинного обучения - это найти некоторую функцию f, что $f(X) -> Y$. Предполагается, что $f$ - это некоторый “закон природы”, который мы можем изучить в процессе обучения, чтобы потом предсказывать значения целевой переменной на новых данных, которые мы никогда не видели.

Рассмотрим пример дерева для датасета о выживших на Титанике.

титаник

Как мы видим в узлах задаются вопросы к объектам по какому-то признаку. А в листьях ответы - предсказания.

Но как по большому набору данных построить одно дерево так, чтобы оно лучше всего предсказывало целевую переменную?

Алгоритм

Главная идея деревьев в том, что они разбивают пространство напополам - левое и правое поддерево. В дереве поиска мы определяли куда пойти дальше, сравнивая новый объект с корнем и если объект больше, то идем в правое поддерево, иначе - в левое. Но в данном случае, мы не можем просто взять и вставить всю обучающую выборку в дерево поиска. Во первых - объекты сами по себе не сравнимы, сравниваются признаки объектов. В узлах мы будем сравнивать признаки. Во вторых - дерево получилось бы огромным, да и мы не решаем задачу поиска существующих элементов, мы решаем задачу классификации несуществующих объектов.

Таким образом дерево принятия решений в каждом своем узле делит все объекты на 2 части по одному какому-то признаку с целью как можно быстрее получить результат предсказания. Это как Акинатор!

На каждом своём шаге алгоритм ищет наилучшее, в некотором смысле, разбиение всех объектов на 2 части.

Что значит ищет? Значит, что алгоритм перебирает все возможные разбиения всех объектов по всем их признакам (O(n^2)), вычисляет качество для каждого разбиения и выбирает то разбиение, для которого качество наилучшее.

Разбиение

по каждому признаку:
    по каждому объекту выборки:
        вычислить разбиение всех объектов на 2 части используя для разбиения значение объект[признак].
        вычислить коэффициент качества
вернуть лучшее разбиение

Что такое качество разбиения? Это некоторая метрика, которая получает на вход два массива объектов, вычисляет какое-то число, которое чем больше тем лучше. Например, если у нас дана выборка жирно выделенных и нет чисел [1,2,3,4,5,6,7], мы создаем 2 разбиения [1] и [2,3,4,5,6,7], то мы говорим, что оно Ужасно плохое, так как нам надо еще несколько шагов для принятия окончательного ответа. А разбиение [1,2,3],[4,5,6,7] - идеальное! Почему мы понимаем, что качество у первого разбиения хуже? Потому что он слишком много взял объектов разных классов в правое разбиение.

Это как если в игре Акинатор бы перебирал всех актеров, вместо того чтобы спросить пол, возраст и страну для начала, тем самым уменьшив количество вариантов с каждым шагом вдвое.

Наивный алгоритм вычисления качества. Посчитаем долю объектов разных классов в левом и правом разбиении (по сути - это вероятность выбора объекта этого класса в этом множестве). Посчитаем в лоб для left=[1] и right=[2,3,4,5,6,7]. p(bold_class, left) == 1.0, p(regular_class, left) == 0.0, p(bold_class, right) == 2/6=0.33, p(regular_class, left) == 0.67. У нас есть по два числа, но нужно одно чтобы определять где лучше! Можно посчитать сумму квадратов 1.0 ** 2 + 0.0**2 + 0.33**2 + 0.67**2 == 1.5578. Для второго разбиения, где у нас всё идеально, качество будет равно 2. Наихудшее разбиение в свою очередь даст 0.5**2 + 0.5**2 + 0.5**2 + 0.5**2 == 1.

Критерий качества Gini

Улучшим наивный критерий. Для этого всего лишь добавим нормализацию коэффициентов. Но тогда лучший классификатор даст 0. А чем хуже, тем больше значение коэффициента.

Для левого и правого разбиения:
    Для каждого класса:
        p = считаем доли каждого класса (вероятности)
        score += p**2
    нормализуем коэффициент:
        gini += (1-score) * (размер этой группы / всего объектов в левом и правом разбиении)
return gini

Итак мы умеем выбирать лучшее разбиение данных на две части. То есть мы умеем выбирать 1 узел и 2 листа. Но нам нужно же построить целое дерево. ??? . Правильно, строим дальше поддеревья рекурсивно до тех пор пока нужно.

Когда наступает остановка разбиения? Первое и интуитивное - когда разбиение больше не нужно, так как у нас в листе все объекты одного класса. Поэтому мы можем просто сказать - всё, если сюда забрели, то тут мы предсказываем, что объект класса А. В принципе нам этого условия будет достаточно.

Реализация алгоритма

Лучше всего сделать некоторый класс - контейнер нашего обученного дерева. Зададим следующий интерфейс, он нам еще понадобится:

class DecisionTree:
    def fit(X, y):
        # тут тренировка модели
        pass

    def predict(X):
        # тут мы возвращаем предсказание
        pass

    def visualise():
        # не обязательно, но хотелось бы посмотреть как дерево внутри построилось
        pass

Тесты писать обязательно! На каждый метод - на разбиение, на коэффициент качества разбиение, на определения терминальности разбиения.

Да как оно должно работать то?

Тренировка модели

  1. выбрать наилучшее разбиение выборки
    1. перебрать все признаки(столбцы) и все объекты(строки)
      1. построить разбиение выборки по значению в данной строке и столбце
      2. вычислить коэффициент качества разбиения
      3. выбираем лучшее разбиение
    2. проверяем терминальность каждого разбиения
      • если там все объекты одного класса(Y), тогда это конце
      • иначе рекурсивно выбираем наилучшее разбиение по подвыборке

В листьях хранятся метки классов, в простом случае просто 0 и 1.

В узлах дерева - условие вида Признак_i < Значения, например, x['Рост'] < 170.

Предсказание

Спускаемся объектом по дереву вниз, применяя условия из узлов дерева к объекту, до листа - самого предсказание.

Использование алгоритма

Настало время научить наше дерево предсказывать что-нибудь. Интересные датасеты: Титаник, банкноты, ирисы.

Чтобы проверять насколько наше дерево хорошо справляется с задачей мы будем вычислять интуитивную метрику качества - долю правильных ответов(accuracy). Есть еще метрика F1, которая учитывает ошибки первого и второго рода, но можно и без неё.

Hints

  • pandas - библиотека для работы с таблицами и csv
  • numpy - численный питон - быстрые алгоритмы над многомерными массивами
  • matplotlib, Seaborn - визуализация
  • scipy - статистические функции
  • sklearn - библиотека алгоритмов машинного обучения

Литература