Чат-бот на Python (Deep Learning + TensorFlow). Часть VI: набор данных для обучения

Предыдущая статья: Чат-бот на Python (Deep Learning + Tensorflow). Часть V: построение базы данных.

Добро пожаловать в шестую часть серии статей про создание чат-бота при помощи алгоритмов глубокого обучения и библиотеки TensorFlow. В этой части мы будем работать над созданием обучающих данных.

Сейчас у нас есть два варианта рабочих моделей, с которыми можно развивать данную серию статей. Одна из этих моделей нормально работает, а другая, вероятно, может работать еще лучше, но с ней у нас есть пока несколько проблем. Впрочем, в любом случае наши настройки для обучающих данных весьма похожи.

Нам нужно создать два файла: в первом будут родительские комментарии, а во втором — ответы на них. А номер строки будет служить идентификатором. Например, строка 15 в «родительском» (первом) файле содержит некий комментарий, а строка 15 в файле ответов, соответственно, ответ на этот комментарий.

Чтобы создать эти файлы, нам всего лишь нужно вытаскивать пары комментарий + ответ из нашей базы данных и добавлять соответствующие части этих пар в разные файлы одновременно. Начнем со следующего кода:

import sqlite3
import pandas as pd

timeframes = ['2015-05']


for timeframe in timeframes: 

Запустив этот код, мы переберем один месяц в созданной нами базе данных. Мы конечно же могли бы создать базу данных, в которой было бы несколько таблиц с разными месяцами и годами. Или, например, можно создать несколько баз данных, аналогичных нашей, для разных месяцев и годов. Как бы то ни было, у нас сейчас есть только один месяц, и мы оставим переменную timeframes хранить список с одним элементом.

[machinelearning_ad_block]

Итак, продолжим писать тело этого цикла:

for timeframe in timeframes:
    connection = sqlite3.connect('{}.db'.format(timeframe))
    c = connection.cursor()
    limit = 5000
    last_unix = 0
    cur_length = limit
    counter = 0
    test_done = False 

Первая строчка в теле цикла просто устанавливает соединение с базой данных. Затем мы определяем курсор, а после — переменную limit.

Эта переменная задает размер фрагмента, который мы собираемся извлекать из базы данных за один раз. Имейте ввиду, что мы снова работаем с данными, размер которых, вероятно, намного больше нашей оперативной памяти. Для начала установим этот размер равным 5000, так как вначале мы будем создавать тестовую выборку. В дальнейшем этот лимит можно будет увеличить.

Мы будем использовать переменную last_unix для извлечения данных из базы, cur_length сообщит нам, когда мы закончим, counter даст нам возможность выводить некоторую отладочную информацию, а test_done покажет нам, когда мы закончим создание тестовых данных.

    while     while cur_length == limit:

        df = pd.read_sql("SELECT * FROM parent_reply WHERE unix > {} and parent NOT NULL and score > 0 ORDER BY unix ASC LIMIT {}".format(last_unix,limit),connection)
        last_unix = df.tail(1)['unix'].values[0]
        cur_length = len(df) == limit:

Пока переменная cur_length равна размеру фрагмента, установленному в переменной limit, мы можем продолжать извлекать данные из базы, помещая их в датафрейм. На данный момент мы не производим никаких операций с этим датафреймом, но в дальнейшем мы можем его использовать, чтобы наложить определенные ограничения на типы данных, которые мы хотим изучать.

Мы сохраняем значение переменной last_unix, чтобы знать, с какого места нам начать следующую итерацию. Мы также фиксируем размер полученных данных.

Теперь возьмем эти данные и используем их для создания обучающей и тестовой выборки. Начнем со второй:

        if not test_done:
            with open('test.from','a', encoding='utf8') as f:
                for content in df['parent'].values:
                    f.write(content+'\n')

            with open('test.to','a', encoding='utf8') as f:
                for content in df['comment'].values:
                    f.write(str(content)+'\n')

            test_done = True 

Теперь при желании можно увеличить значение переменной limit. После строки test_done = True вы можете ее переопределить на любое значение, например на 100000.

Давайте теперь займемся написанием кода для создания обучающей выборки:

        else:
            with open('train.from','a', encoding='utf8') as f:
                for content in df['parent'].values:
                    f.write(content+'\n')

            with open('train.to','a', encoding='utf8') as f:
                for content in df['comment'].values:
                    f.write(str(content)+'\n') 

Мы могли бы сделать этот код проще и лучше, превратив его в функцию Это позволило бы обойтись без копипаста одного кода в разные места. Но давайте просто продолжим:

        counter += 1
        if counter % 20 == 0:
            print(counter*limit,'rows completed so far') 

Здесь мы будем видеть результат для каждых 20 шагов, то есть каждые 100000 пар, если мы сохраним значение переменной limit равным 5000.

Итак, сейчас весь наш код имеет следующий вид:

import sqlite3
import pandas as pd

timeframes = ['2015-05']

for timeframe in timeframes:
    connection = sqlite3.connect('{}.db'.format(timeframe))
    c = connection.cursor()
    limit = 5000
    last_unix = 0
    cur_length = limit
    counter = 0
    test_done = False

    while cur_length == limit:

        df = pd.read_sql("SELECT * FROM parent_reply WHERE unix > {} and parent NOT NULL and score > 0 ORDER BY unix ASC LIMIT {}".format(last_unix,limit),connection)
        last_unix = df.tail(1)['unix'].values[0]
        cur_length = len(df)

        if not test_done:
            with open('test.from','a', encoding='utf8') as f:
                for content in df['parent'].values:
                    f.write(content+'\n')

            with open('test.to','a', encoding='utf8') as f:
                for content in df['comment'].values:
                    f.write(str(content)+'\n')

            test_done = True

        else:
            with open('train.from','a', encoding='utf8') as f:
                for content in df['parent'].values:
                    f.write(content+'\n')

            with open('train.to','a', encoding='utf8') as f:
                for content in df['comment'].values:
                    f.write(str(content)+'\n')

        counter += 1
        if counter % 20 == 0:
            print(counter*limit,'rows completed so far') 

Хорошо, теперь запускайте этот код, и увидимся, когда у вас будут готовы данные!

Следующая статья — Чат-бот на Python (Deep Learning + TensorFlow). Часть VII: обучение модели.