Играем в GTA V c Python. Часть XI: Обучение нейронной сети для автопилота

Предыдущая статья — Играем в GTA V c Python. Часть X: балансировка данных для обучения нейронной сети.

Добро пожаловать в одиннадцатую часть серии статей про применение методов машинного обучения в игре Grand Theft Auto V. Мы продолжаем создавать наш беспилотный автомобиль в этой игре.

К настоящему моменту мы построили датасет, состоящий из изображений 800Х600, каждому из которых соответствует клавиатурный ввод A,W, и D(налево, прямо и направо соответственно).

Далее нам нужно создать и обучить нейронную сеть. Для визуальных данных, как правило, чаще всего используются сверточные нейронные сети (CNN). Здесь у нас есть огромный выбор уже готовых нейронных сетей с различным количеством слоев, узлов и с разными видами функций активации. На данном этапе мы просто выбрали из них AlexNet.

Мы будем использовать библиотеку TFLearn (библиотеку абстракций TensorFlow). Если вы хотите использовать TensorFlow в чистом виде, или Keras, или Teano, или что-то еще, ради бога! А если вы хотите узнать о библиотеке TFLearn больше, то вот пособие к ней.

Итак, давайте зададим нашу модель. Наверно, для этого лучше создать отдельный файл, назовем его alexnet.py.

# alexnet.py

import tflearn
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from tflearn.layers.normalization import local_response_normalization

def alexnet(width, height, lr):
    network = input_data(shape=[None, width, height, 1], name='input')
    network = conv_2d(network, 96, 11, strides=4, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = conv_2d(network, 256, 5, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = conv_2d(network, 384, 3, activation='relu')
    network = conv_2d(network, 384, 3, activation='relu')
    network = conv_2d(network, 256, 3, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = fully_connected(network, 4096, activation='tanh')
    network = dropout(network, 0.5)
    network = fully_connected(network, 4096, activation='tanh')
    network = dropout(network, 0.5)
    network = fully_connected(network, 3, activation='softmax')
    network = regression(network, optimizer='momentum',
                         loss='categorical_crossentropy',
                         learning_rate=lr, name='targets')

    model = tflearn.DNN(network, checkpoint_path='model_alexnet',
                        max_checkpoints=1, tensorboard_verbose=2, tensorboard_dir='log')

    return model 

Теперь поработаем над файлом train_model.py:

# train_model.py

import numpy as np
from alexnet import alexnet

WIDTH = 80
HEIGHT = 60
LR = 1e-3
EPOCHS = 8
MODEL_NAME = 'pygta5-car-{}-{}-{}-epochs.model'.format(LR, 'alexnetv2',EPOCHS)

model = alexnet(WIDTH, HEIGHT, LR)
Setup the training data:

train_data = np.load('training_data_v2.npy')

train = train_data[:-500]
test = train_data[-500:]

X = np.array([i[0] for i in train]).reshape(-1,WIDTH,HEIGHT,1)
Y = [i[1] for i in train]

test_x = np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,1)
test_y = [i[1] for i in test]
Now we can actually train the model with:

model.fit({'input': X}, {'targets': Y}, n_epoch=EPOCHS, validation_set=({'input': test_x}, {'targets': test_y}), 
    snapshot_step=500, show_metric=True, run_id=MODEL_NAME)

# tensorboard --logdir=foo:C:/Users/H/Desktop/ai-gaming/log

model.save(MODEL_NAME) 

Если вы хотите видеть прогресс обучения вашей нейронной сети, то можете запустить в терминале команду, которая закоментирована в конце файла.

Завершив обучение, мы готовы к следующему шагу.

Мы в обучении использовали 8 эпох и рекомендуем вам исходить из количества от 5 до 15. Если вам нужно больше информации по выбору количества эпох и по Tensor Board, смотрите наше видео.

Следующая статья — Играем в GTA V c Python. Часть XII: тестируем нейронную сеть для автопилота.