from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt
# ---------------------------
# 1. Load the Data
# ---------------------------
# Assume your dataset is already loaded into these variables:
# X_train, y_train, X_val, y_val
# Uncomment the following block if you need dummy data for testing.
# Replace this with your actual dataset.
# ------------------------------------------------------------------------------
# num_samples_train = 1000
# num_features = 30 # e.g., 10 limbs * 3 components each
# num_classes = 5 # number of pose categories
# # Dummy feature data: random numbers (in real cases, these are your pose vectors)
# num_samples_val = 200  # validation set size (was referenced below but never defined)
# X_train = np.random.random((num_samples_train, num_features))
# X_val = np.random.random((num_samples_val, num_features))
# # Dummy label data: random integers converted to one-hot encoding
# y_train = tf.keras.utils.to_categorical(np.random.randint(0, num_classes, num_samples_train), num_classes=num_classes)
# y_val = tf.keras.utils.to_categorical(np.random.randint(0, num_classes, num_samples_val), num_classes=num_classes)
# ------------------------------------------------------------------------------
# ---------------------------
# 2. Build the Model
# ---------------------------
# Feed-forward classifier: 30 pose features in, 5 pose classes out.
# (Reconstructed: the `model = Sequential([...])` wrapper was missing, leaving
# these layer calls as orphaned expressions that built no model.)
model = Sequential([
    Dense(64, input_dim=30, activation='relu'),  # input_dim should match the number of features
    Dropout(0.5),                                # helps to prevent overfitting
    Dense(32, activation='relu'),
    Dense(5, activation='softmax'),              # output layer: 5 classes with softmax activation
])
# ---------------------------
# 3. Compile the Model
# ---------------------------
loss='categorical_crossentropy', # Suitable for multi-class classification
# ---------------------------
# 4. Train the Model
# ---------------------------
# Using EarlyStopping to monitor validation loss and avoid overfitting:
# training stops once val_loss fails to improve for 5 consecutive epochs, and
# the weights from the best epoch are restored.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# (Reconstructed: the `model.fit(...)` call was missing, leaving the
# `validation_data=` keyword orphaned and `early_stop` unused.)
# NOTE(review): epochs/batch_size were absent from this chunk — 100/32 are
# conventional defaults and EarlyStopping ends training well before 100 epochs;
# confirm against the original script if available.
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    batch_size=32,
    callbacks=[early_stop],
)
# ---------------------------
# 5. Evaluate the Model
# ---------------------------
# Report final performance on the held-out validation set.
# (Idiom: f-strings replace the dated str.format calls; output is identical.)
val_loss, val_accuracy = model.evaluate(X_val, y_val)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")
# ---------------------------
# 6. Plot Training History
# ---------------------------
# Loss and accuracy each get their own axes. The original drew all four curves
# onto one set of axes, the second plt.title clobbered the first, and no
# legend or plt.show() was ever issued.
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()