# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Build and train neural networks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import datetime
import os

# pylint: disable=duplicate-code
from data_load import DataLoader

import numpy as np  # pylint: disable=duplicate-code
import tensorflow as tf

# One TensorBoard log directory per process start, timestamped to the second.
logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)


def reshape_function(data, label):
    """Reshape a sample to (-1, 3, 1) so it matches the CNN's Conv2D input.

    Args:
      data: sample tensor; its elements are reshaped to shape (-1, 3, 1).
      label: passed through unchanged.

    Returns:
      A (reshaped_data, label) tuple.
    """
    reshaped_data = tf.reshape(data, [-1, 3, 1])
    return reshaped_data, label


def calculate_model_size(model):
    """Print the Keras model summary and its trainable-variable size in KB.

    Args:
      model: a built Keras model.
    """
    # summary() prints on its own and returns None; wrapping it in print()
    # would emit a stray "None" line.
    model.summary()
    var_sizes = [
        # np.prod replaces np.product, which NumPy deprecated and removed.
        np.prod(list(map(int, v.shape))) * v.dtype.size
        for v in model.trainable_variables
    ]
    print("Model size:", sum(var_sizes) / 1024, "KB")


def build_cnn(seq_length):
    """Builds a convolutional neural network in Keras.

    If a previously saved ``weights.h5`` exists under ``./netmodels/CNN`` it is
    loaded; otherwise the model starts from freshly initialized weights.

    Args:
      seq_length: number of time steps per sample (first input dimension).

    Returns:
      A (model, model_path) tuple; model_path is the directory created for
      this model's artifacts.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            8, (4, 3),
            padding="same",
            activation="relu",
            input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
        tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
        tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
        tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                               activation="relu"),  # (batch, 42, 1, 16)
        tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
        tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
        tf.keras.layers.Flatten(),  # (batch, 224)
        tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
        tf.keras.layers.Dropout(0.1),  # (batch, 16)
        tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
    ])
    model_path = os.path.join("./netmodels", "CNN")
    print("Built CNN.")
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    # Only warm-start when a checkpoint is actually present; loading
    # unconditionally crashed every first-time training run.
    weights_path = os.path.join(model_path, "weights.h5")
    if os.path.exists(weights_path):
        model.load_weights(weights_path)
    return model, model_path


def build_lstm(seq_length):
    """Builds an LSTM in Keras.

    Args:
      seq_length: number of time steps per sample.

    Returns:
      A (model, model_path) tuple; model_path is the directory created for
      this model's artifacts.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(22),
            input_shape=(seq_length, 3)),  # output_shape=(batch, 44)
        # softmax (not sigmoid): the model is trained with
        # sparse_categorical_crossentropy and decoded with argmax, both of
        # which expect a per-sample probability distribution over 4 classes.
        tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
    ])
    model_path = os.path.join("./netmodels", "LSTM")
    print("Built LSTM.")
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    return model, model_path


def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
    """Load and format the train/valid/test splits via DataLoader.

    Returns:
      (train_len, train_data, valid_len, valid_data, test_len, test_data), as
      exposed by DataLoader after ``format()``.
    """
    data_loader = DataLoader(
        train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
    data_loader.format()
    return (data_loader.train_len, data_loader.train_data,
            data_loader.valid_len, data_loader.valid_data,
            data_loader.test_len, data_loader.test_data)


def build_net(args, seq_length):
    """Build the network selected by ``args.model`` ("CNN" or "LSTM").

    Raises:
      ValueError: if ``args.model`` names neither supported model. (Previously
        this printed a message and then crashed with UnboundLocalError.)
    """
    if args.model == "CNN":
        model, model_path = build_cnn(seq_length)
    elif args.model == "LSTM":
        model, model_path = build_lstm(seq_length)
    else:
        raise ValueError("Please input correct model name.(CNN LSTM)")
    return model, model_path


def _save_tflite_model(model, path, optimizations=None):
    """Convert `model` to TensorFlow Lite, write it to `path`, return its size.

    Args:
      model: Keras model to convert.
      path: output file path.
      optimizations: optional list of ``tf.lite.Optimize`` values.

    Returns:
      Size of the written file in bytes.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    if optimizations:
        converter.optimizations = optimizations
    tflite_model = converter.convert()
    with open(path, "wb") as f:  # closes the file deterministically
        f.write(tflite_model)
    return os.path.getsize(path)


def train_net(
        model,
        model_path,  # pylint: disable=unused-argument
        train_len,  # pylint: disable=unused-argument
        train_data,
        valid_len,
        valid_data,
        test_len,
        test_data,
        kind):
    """Trains the model, reports test metrics and exports TFLite models.

    Writes ``model.tflite`` (float) and ``model_quantized.tflite`` (with the
    default TFLite optimizations) to the current directory.

    Args:
      model: Keras model from build_net.
      model_path: unused; kept for interface compatibility.
      train_len: unused; kept for interface compatibility.
      train_data, valid_data, test_data: unbatched datasets of
        (sample, label) pairs.
      valid_len: number of validation samples (used for validation_steps).
      test_len: number of test samples (sizes the label buffer).
      kind: "CNN" adds the channel axis expected by the convolutional net.
    """
    calculate_model_size(model)
    epochs = 50
    batch_size = 64
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"])
    if kind == "CNN":
        train_data = train_data.map(reshape_function)
        test_data = test_data.map(reshape_function)
        valid_data = valid_data.map(reshape_function)

    # Collect test labels in dataset order (before batching) so they line up
    # with model.predict's output for the confusion matrix.
    test_labels = np.zeros(test_len)
    for idx, (_, label) in enumerate(test_data):
        test_labels[idx] = label.numpy()

    # The training set repeats forever; steps_per_epoch bounds each epoch.
    train_data = train_data.batch(batch_size).repeat()
    valid_data = valid_data.batch(batch_size)
    test_data = test_data.batch(batch_size)
    model.fit(
        train_data,
        epochs=epochs,
        validation_data=valid_data,
        steps_per_epoch=1000,
        # Ceiling division so a trailing partial batch is still validated.
        validation_steps=(valid_len + batch_size - 1) // batch_size,
        callbacks=[tensorboard_callback])
    loss, acc = model.evaluate(test_data)
    pred = np.argmax(model.predict(test_data), axis=1)
    confusion = tf.math.confusion_matrix(
        labels=tf.constant(test_labels),
        predictions=tf.constant(pred),
        num_classes=4)
    print(confusion)
    print("Loss {}, Accuracy {}".format(loss, acc))

    # Convert the model to the TensorFlow Lite format without quantization.
    basic_model_size = _save_tflite_model(model, "model.tflite")
    # Convert with quantization. Optimize.DEFAULT replaces the deprecated
    # OPTIMIZE_FOR_SIZE alias, which performed the same optimization.
    quantized_model_size = _save_tflite_model(
        model, "model_quantized.tflite",
        optimizations=[tf.lite.Optimize.DEFAULT])

    print("Basic model is %d bytes" % basic_model_size)
    print("Quantized model is %d bytes" % quantized_model_size)
    difference = basic_model_size - quantized_model_size
    print("Difference is %d bytes" % difference)


def main():
    """Parse command-line flags, load data, build the chosen net and train it."""
    parser = argparse.ArgumentParser(allow_abbrev=False)
    # choices makes a misspelled model name fail immediately, before the
    # (slow) data load, instead of after it.
    parser.add_argument("--model", "-m", choices=("CNN", "LSTM"),
                        help="Network to train: CNN or LSTM.")
    parser.add_argument("--person", "-p",
                        help='Pass "true" to use the ./person_split dataset.')
    args = parser.parse_args()

    seq_length = 128

    print("Start to load data...")
    if args.person == "true":
        splits = load_data("./person_split/train", "./person_split/valid",
                           "./person_split/test", seq_length)
    else:
        splits = load_data("./data/train", "./data/valid", "./data/test",
                           seq_length)
    train_len, train_data, valid_len, valid_data, test_len, test_data = splits

    print("Start to build net...")
    model, model_path = build_net(args, seq_length)

    print("Start training...")
    train_net(model, model_path, train_len, train_data, valid_len, valid_data,
              test_len, test_data, args.model)

    print("Training finished!")


if __name__ == "__main__":
    main()