import numpy as np
# Replace `Your_last_name` in the import below with your own last name,
# for example:
# from Kamangar_01_01 import multi_layer_nn
from Your_last_name_01_01 import multi_layer_nn


def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    Evaluated as ``exp(-logaddexp(0, -x))``, which is algebraically identical
    (``logaddexp(0, -x) == log(1 + exp(-x))``) but never computes ``exp`` of a
    large positive number. Very negative inputs therefore give 0.0 without an
    overflow warning, and very positive inputs give 1.0.
    """
    return np.exp(-np.logaddexp(0, -x))


def create_toy_data(input_dimension=4, output_dimension=1, n_samples=100, validation_ratio=0.1, form="quadratic",
                    seed=1234):
    np.random.seed(seed)
    X = np.random.uniform(-1, 1, (n_samples, input_dimension))
    b = np.random.uniform(-1, 1, output_dimension)  # Constant
    c = np.random.uniform(-1, 1, (input_dimension, output_dimension))  # Coefficients
    match form.lower():
        case "linear":
            Y = np.dot(X, c) + b
        case "quadratic":
            Q = np.random.uniform(-1, 1, (output_dimension, input_dimension, input_dimension))
            Q = 0.5 * (Q + Q.transpose(0, 2, 1))  # make Q_k symmetric
            quad = np.einsum('ni,kij,nj->nk', X, Q, X)  # quadratic term for each output
            Y = b + np.dot(X, c) + quad
        case _:  # Create random Y
            Y = np.random.randn(n_samples, output_dimension)
    validation_index = np.floor((1.0 - validation_ratio) * n_samples).astype(int)
    X_train = X[:validation_index, :]
    Y_train = Y[:validation_index, :]
    X_test = X[validation_index:, :]
    Y_test = Y[validation_index:, :]
    return X_train, Y_train, X_test, Y_test


def test_can_fit_data():
    """A two-layer net trained for 10 epochs on 1-D targets reaches the reference error."""
    np.random.seed(12345)
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=100)
    layer_sizes = [2, Y_train.shape[1]]
    _, err, _ = multi_layer_nn(
        X_train, Y_train, X_test, Y_test,
        layers=layer_sizes, alpha=0.001, epochs=10, h=0.0001, seed=1234,
    )
    final_error = err[9]
    assert abs(final_error - 1.0108) < 1e-3


def test_can_fit_data_2d():
    """Training on 2-D targets lowers the error on each of the first epochs and hits the reference value."""
    np.random.seed(1234)
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=100, output_dimension=2)

    _, err, _ = multi_layer_nn(
        X_train, Y_train, X_test, Y_test,
        layers=[2, Y_train.shape[1]], alpha=0.1, epochs=10, h=1e-5, seed=1234,
    )
    # The error must strictly decrease from epoch 0 -> 1 -> 2 -> 3.
    for epoch in range(1, 4):
        assert err[epoch] < err[epoch - 1]
    assert err[9] < 0.9
    assert abs(err[9] - 0.8333) < 1e-3


def test_check_weight_init():
    """With epochs=0 the returned weights must equal the reference initial weights for every layer."""
    np.random.seed(12345)
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=110, output_dimension=3)
    W, _, _ = multi_layer_nn(X_train, Y_train, X_test, Y_test, layers=[2, 4, 3, Y_train.shape[1]], alpha=0.35,
                             epochs=0, h=1e-8, seed=1234)

    expected_weights = [
        np.array([[0.47143516, -1.19097569],
                  [1.43270697, -0.3126519],
                  [-0.72058873, 0.88716294],
                  [0.85958841, -0.6365235],
                  [0.01569637, -2.24268495]]),
        np.array([[0.47143516, -1.19097569, 1.43270697, -0.3126519],
                  [-0.72058873, 0.88716294, 0.85958841, -0.6365235],
                  [0.01569637, -2.24268495, 1.15003572, 0.99194602]]),
        np.array([[0.47143516, -1.19097569, 1.43270697],
                  [-0.3126519, -0.72058873, 0.88716294],
                  [0.85958841, -0.6365235, 0.01569637],
                  [-2.24268495, 1.15003572, 0.99194602],
                  [0.95332413, -2.02125482, -0.33407737]]),
        np.array([[0.47143516, -1.19097569, 1.43270697],
                  [-0.3126519, -0.72058873, 0.88716294],
                  [0.85958841, -0.6365235, 0.01569637],
                  [-2.24268495, 1.15003572, 0.99194602]]),
    ]

    # Scoped print options: np.set_printoptions would change NumPy's global
    # state and leak into every test that runs afterwards.
    with np.printoptions(threshold=np.inf):
        for layer in range(4):
            print(repr(W[layer]))

    for layer, expected in enumerate(expected_weights):
        assert np.allclose(W[layer], expected), f"W[{layer}] does not match the expected initial weights"


def test_large_alpha_test():
    """A learning rate that is too large must keep the final error above 1.

    With alpha this big, each update moves the weights too far, so the error
    either grows or barely improves.
    """
    np.random.seed(12345)
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=110)

    _, err, _ = multi_layer_nn(
        X_train, Y_train, X_test, Y_test,
        layers=[2, Y_train.shape[1]], alpha=0.9, epochs=10, h=1, seed=2,
    )
    last_error = err[-1]
    assert last_error > 1


def test_small_alpha_test():
    """With a tiny learning rate (1e-9) the weights barely move, so the error must stay essentially constant."""
    np.random.seed(12345)
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=110)

    _, err, _ = multi_layer_nn(
        X_train, Y_train, X_test, Y_test,
        layers=[2, Y_train.shape[1]], alpha=1e-9, epochs=10, h=1e-8, seed=2,
    )
    tolerance = 1e-5
    assert abs(err[-1] - err[-2]) < tolerance
    assert abs(err[1] - err[0]) < tolerance


def test_number_of_nodes_test():
    """Weight-matrix shapes must follow the requested layer sizes.

    Each matrix has one row more than the size of the layer feeding it
    (presumably the bias row): 4 inputs -> 5 rows, 100 hidden nodes -> 101 rows.
    """
    np.random.seed(12345)
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=11)

    for hidden_nodes in (100, 42):
        W, _, _ = multi_layer_nn(
            X_train, Y_train, X_test, Y_test,
            layers=[hidden_nodes, Y_train.shape[1]], alpha=1e-9, epochs=0, h=1e-8, seed=2,
        )
        assert W[0].shape == (5, hidden_nodes)
        assert W[1].shape == (hidden_nodes + 1, 1)

    X_train, Y_train, X_test, Y_test = create_toy_data(input_dimension=5, n_samples=10, output_dimension=9)
    W, _, _ = multi_layer_nn(
        X_train, Y_train, X_test, Y_test,
        layers=[42, Y_train.shape[1]], alpha=1e-9, epochs=0, h=1e-8, seed=2,
    )
    assert W[0].shape == (6, 42)
    assert W[1].shape == (43, 9)


def test_check_output_shape_2d():
    """The returned test-set outputs must have the same shape as ``Y_test`` (7 target columns)."""
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=110, output_dimension=7)

    # Pass `layers` by keyword, like every other call in this module.
    _, _, Out = multi_layer_nn(X_train, Y_train, X_test, Y_test, layers=[2, Y_train.shape[1]], alpha=0.35,
                               epochs=1, h=1e-8, seed=1234)
    assert Out.shape == Y_test.shape


def test_check_output_values():
    """An untrained (epochs=0) network must produce the reference outputs on the test split."""
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=110, output_dimension=2)

    _, _, Out = multi_layer_nn(X_train, Y_train, X_test, Y_test, layers=[2, Y_train.shape[1]], alpha=0.35, epochs=0,
                               h=1e-8, seed=1234)
    # Scoped print options: np.set_printoptions would leak into later tests.
    with np.printoptions(threshold=np.inf):
        print(repr(Out))

    expected_Out = np.array([[0.46675414, -0.69148138],
                             [0.69011347, -0.75985168],
                             [0.53205552, -0.61500012],
                             [1.07960278, -1.25175119],
                             [1.42686729, -1.35043939],
                             [1.45047611, -1.10725796],
                             [1.35948139, -1.28415832],
                             [1.31103035, -1.28881301],
                             [0.41432377, -0.59684387],
                             [1.62037671, -1.2381968],
                             [1.1523584, -1.26656577]])

    assert np.allclose(Out, expected_Out, atol=1e-5)


def test_check_weight_update():
    """One training epoch must change both weight matrices by the reference deltas."""
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=110, output_dimension=2)
    # Both runs use seed=1234. The deltas below rely on that seed producing
    # identical initial weights, so the only difference is one epoch of updates.
    W_before, _, _ = multi_layer_nn(X_train, Y_train, X_test, Y_test, layers=[2, Y_train.shape[1]], alpha=0.2,
                                    epochs=0, h=1e-8, seed=1234)
    W_after, _, _ = multi_layer_nn(X_train, Y_train, X_test, Y_test, layers=[2, Y_train.shape[1]], alpha=0.2,
                                   epochs=1, h=1e-8, seed=1234)
    delta1 = W_after[0] - W_before[0]
    delta2 = W_after[1] - W_before[1]
    # Scoped print options: np.set_printoptions would leak into later tests.
    with np.printoptions(threshold=np.inf):
        print(repr(delta1))
        print(repr(delta2))

    correct_delta1 = np.array([[-0.9470275, -0.43531655],
                               [-0.16798027, -0.15437864],
                               [1.15226623, -0.74260675],
                               [0.31322348, -0.14407453],
                               [-0.3414342, 0.13365055]])
    correct_delta2 = np.array([[0.46874882, 1.43979206],
                               [-1.54724004, -0.95096358],
                               [1.06605661, -0.41162799]])
    assert np.allclose(delta1, correct_delta1, atol=1e-5)
    assert np.allclose(delta2, correct_delta2, atol=1e-5)


def test_h_value_used():
    """The finite-difference step ``h`` must influence the weight update.

    Two runs that differ only in ``h`` (1e-8 vs 10) are each trained for one
    epoch from the same seed. If ``h`` is actually used to estimate the
    gradient, the resulting weights must differ. Comparing against an untrained
    network instead would pass for any ``h``, because any training changes the
    weights, so it would not prove that ``h`` is used.
    """
    X_train, Y_train, X_test, Y_test = create_toy_data(n_samples=110, output_dimension=2)
    W_small_h, _, _ = multi_layer_nn(X_train, Y_train, X_test, Y_test, layers=[2, Y_train.shape[1]], alpha=0.2,
                                     epochs=1, h=1e-8, seed=1234)
    W_large_h, _, _ = multi_layer_nn(X_train, Y_train, X_test, Y_test, layers=[2, Y_train.shape[1]], alpha=0.2,
                                     epochs=1, h=10, seed=1234)

    assert not np.allclose(W_large_h[0], W_small_h[0], atol=1e-5)
    assert not np.allclose(W_large_h[1], W_small_h[1], atol=1e-5)
