You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
202 lines
7.4 KiB
202 lines
7.4 KiB
# Lint as: python3 |
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
# ============================================================================== |
|
# pylint: disable=g-bad-import-order |
|
|
|
"""Build and train neural networks.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import argparse |
|
import datetime |
|
import os # pylint: disable=duplicate-code |
|
from data_load import DataLoader |
|
|
|
import numpy as np # pylint: disable=duplicate-code |
|
import tensorflow as tf |
|
|
|
# Timestamp the TensorBoard log directory so successive runs keep
# separate scalar histories instead of overwriting each other.
_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/scalars/" + _run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
|
|
|
|
|
def reshape_function(data, label):
  """Add a trailing channel axis so each sample becomes (steps, 3, 1).

  Passes the label through unchanged; intended for tf.data.Dataset.map.
  """
  return tf.reshape(data, [-1, 3, 1]), label
|
|
|
|
|
def calculate_model_size(model):
  """Print a model summary and the size of its trainable variables.

  Args:
    model: object exposing `summary()` and `trainable_variables`; each
      variable must have a `shape` (iterable of ints) and a `dtype.size`
      giving the element size in bytes (true for tf.Variable — NOTE numpy
      dtypes call this `itemsize`, not `size`).

  Returns:
    Total size of the trainable variables in bytes (previously None;
    callers that ignore the return value are unaffected).
  """
  print(model.summary())
  # np.prod replaces np.product, which was deprecated and removed in
  # NumPy 2.0.
  var_sizes = [
      np.prod(list(map(int, v.shape))) * v.dtype.size
      for v in model.trainable_variables
  ]
  total_bytes = sum(var_sizes)
  print("Model size:", total_bytes / 1024, "KB")
  return total_bytes
|
|
|
|
|
def build_cnn(seq_length):
  """Builds a convolutional neural network in Keras.

  Args:
    seq_length: number of time steps per sample; the network input shape
      is (seq_length, 3, 1) — accelerometer x/y/z plus a channel axis.

  Returns:
    Tuple of (uncompiled tf.keras model, checkpoint directory path).
  """
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          8, (4, 3),
          padding="same",
          activation="relu",
          input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
      tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
      tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
      tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                             activation="relu"),  # (batch, 42, 1, 16)
      tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
      tf.keras.layers.Flatten(),  # (batch, 224)
      tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 16)
      tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "CNN")
  print("Built CNN.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  # Only resume from pretrained weights when the checkpoint actually exists.
  # The original unconditionally called load_weights, which crashed on a
  # fresh checkout where ./netmodels/CNN/weights.h5 has not been saved yet.
  weights_file = os.path.join(model_path, "weights.h5")
  if os.path.exists(weights_file):
    model.load_weights(weights_file)
  return model, model_path
|
|
|
|
|
def build_lstm(seq_length):
  """Builds an LSTM in Keras."""
  layers = [
      tf.keras.layers.Bidirectional(
          tf.keras.layers.LSTM(22),
          input_shape=(seq_length, 3)),  # output_shape=(batch, 44)
      tf.keras.layers.Dense(4, activation="sigmoid")  # (batch, 4)
  ]
  model = tf.keras.Sequential(layers)
  model_path = os.path.join("./netmodels", "LSTM")
  print("Built LSTM.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  return model, model_path
|
|
|
|
|
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
  """Load the train/valid/test splits through DataLoader and format them.

  Returns a 6-tuple: (train_len, train_data, valid_len, valid_data,
  test_len, test_data) as exposed by the DataLoader instance.
  """
  loader = DataLoader(
      train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
  loader.format()
  return (loader.train_len, loader.train_data, loader.valid_len,
          loader.valid_data, loader.test_len, loader.test_data)
|
|
|
|
|
def build_net(args, seq_length):
  """Construct the network selected by the command line.

  Args:
    args: parsed CLI namespace; `args.model` selects "CNN" or "LSTM".
    seq_length: number of time steps per sample, forwarded to the builder.

  Returns:
    Tuple of (model, model_path) from the chosen builder.

  Raises:
    ValueError: if args.model is neither "CNN" nor "LSTM".  The original
      printed a message and then crashed with UnboundLocalError on the
      return statement; raising explicitly gives a clear failure mode.
  """
  if args.model == "CNN":
    return build_cnn(seq_length)
  if args.model == "LSTM":
    return build_lstm(seq_length)
  raise ValueError("Please input correct model name.(CNN LSTM)")
|
|
|
|
|
def train_net(
    model,
    model_path,  # pylint: disable=unused-argument
    train_len,  # pylint: disable=unused-argument
    train_data,
    valid_len,
    valid_data,
    test_len,
    test_data,
    kind):
  """Trains the model, evaluates it, and exports TensorFlow Lite files.

  Args:
    model: uncompiled tf.keras model from build_cnn/build_lstm.
    model_path: checkpoint directory (unused in this function).
    train_len: number of training samples (unused; steps_per_epoch is fixed).
    train_data: tf.data.Dataset of (sample, label) training pairs.
    valid_len: number of validation samples, used for validation_steps.
    valid_data: tf.data.Dataset of validation pairs.
    test_len: number of test samples.
    test_data: tf.data.Dataset of test pairs.
    kind: "CNN" or "LSTM"; CNN input needs a trailing channel axis.

  Side effects: writes model.tflite and model_quantized.tflite to the
  current directory and logs to TensorBoard via tensorboard_callback.
  """
  calculate_model_size(model)
  epochs = 50
  batch_size = 64
  model.compile(
      optimizer="adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])
  if kind == "CNN":
    # The CNN expects rank-4 input; add the trailing channel dimension.
    train_data = train_data.map(reshape_function)
    test_data = test_data.map(reshape_function)
    valid_data = valid_data.map(reshape_function)
  # Collect labels BEFORE batching so there is one label per dataset element.
  test_labels = np.zeros(test_len)
  idx = 0
  for data, label in test_data:  # pylint: disable=unused-variable
    test_labels[idx] = label.numpy()
    idx += 1
  train_data = train_data.batch(batch_size).repeat()
  valid_data = valid_data.batch(batch_size)
  test_data = test_data.batch(batch_size)
  model.fit(
      train_data,
      epochs=epochs,
      validation_data=valid_data,
      steps_per_epoch=1000,
      validation_steps=int((valid_len - 1) / batch_size + 1),
      callbacks=[tensorboard_callback])
  loss, acc = model.evaluate(test_data)
  pred = np.argmax(model.predict(test_data), axis=1)
  confusion = tf.math.confusion_matrix(
      labels=tf.constant(test_labels),
      predictions=tf.constant(pred),
      num_classes=4)
  print(confusion)
  print("Loss {}, Accuracy {}".format(loss, acc))

  # Convert the model to the TensorFlow Lite format without quantization.
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  tflite_model = converter.convert()

  # Save the model to disk.  `with` guarantees the handle is closed even if
  # the write fails; the original bare open(...).write(...) leaked it.
  with open("model.tflite", "wb") as model_file:
    model_file.write(tflite_model)

  # Convert the model to the TensorFlow Lite format with quantization.
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  # NOTE(review): OPTIMIZE_FOR_SIZE is a deprecated alias of
  # tf.lite.Optimize.DEFAULT in recent TF releases; kept as-is for
  # compatibility with the TF version this example targets.
  converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
  tflite_model = converter.convert()

  # Save the quantized model to disk.
  with open("model_quantized.tflite", "wb") as model_file:
    model_file.write(tflite_model)

  basic_model_size = os.path.getsize("model.tflite")
  print("Basic model is %d bytes" % basic_model_size)
  quantized_model_size = os.path.getsize("model_quantized.tflite")
  print("Quantized model is %d bytes" % quantized_model_size)
  difference = basic_model_size - quantized_model_size
  print("Difference is %d bytes" % difference)
|
|
|
|
|
if __name__ == "__main__":
  parser = argparse.ArgumentParser(allow_abbrev=False)
  parser.add_argument("--model", "-m")
  parser.add_argument("--person", "-p")
  args = parser.parse_args()

  seq_length = 128

  print("Start to load data...")
  # --person true selects the per-person split; anything else uses ./data.
  if args.person == "true":
    data_paths = ("./person_split/train", "./person_split/valid",
                  "./person_split/test")
  else:
    data_paths = ("./data/train", "./data/valid", "./data/test")
  (train_len, train_data, valid_len, valid_data, test_len,
   test_data) = load_data(*data_paths, seq_length)

  print("Start to build net...")
  model, model_path = build_net(args, seq_length)

  print("Start training...")
  train_net(model, model_path, train_len, train_data, valid_len, valid_data,
            test_len, test_data, args.model)

  print("Training finished!")
|
|
|