///////////////////////////////////////////////////////////////////////////////
// TrefoilKnot.h
// =============
// N-leaf Trefoil Knot geometry for OpenGL with (majorR, minorR, tubeRadius, n, sectors, sides)
// x = r * sin(a) - R * sin((n-1)*a)
// y = r * cos(a) + R * cos((n-1)*a)
// z = r * sin(n*a)
// where n = # of leaves
// 
// The minimum # of sectors is 18 and the minimum # of sides is 2.
// The minimum n is 2, which produces a 2-leaf clover-shaped knot.
// - major radius(R): distance from the origin to the centre of the torus
// - minor radius(r): radius of the torus tube
// - tube radius: radius of the knot tube
// - n: # of leaves of the knot
// - sectors: # of sectors of the tube
// - sides: # of sides of the tube
// - smooth: smooth (default) or flat shading
// - up-axis: facing direction, X=1, Y=2, Z=3(default)
// - mode: tube generation mode, 0=projection, 1=transform
//
//  AUTHOR: Song Ho Ahn (song.ahn@gmail.com)
// CREATED: 2026-02-03
// UPDATED: 2026-02-03
///////////////////////////////////////////////////////////////////////////////

#ifndef GEOMETRY_TREFOILKNOT_H
#define GEOMETRY_TREFOILKNOT_H

#include <vector>
#include <iostream>
#include <cmath>

///////////////////////////////////////////////////////////////////////////////
// TrefoilKnot: generates vertex, normal, texture-coordinate, index and
// line-index arrays (plus an interleaved V/N/T array) for an N-leaf knot tube,
// and draws them in OpenGL vertex-array mode.
//
// The class owns only std::vector buffers, so it follows the Rule of Zero:
// there is deliberately NO user-declared destructor. Declaring one (even an
// empty "~TrefoilKnot() {}") suppresses the implicit move constructor/move
// assignment, which made every move silently deep-copy all geometry buffers.
///////////////////////////////////////////////////////////////////////////////
class TrefoilKnot
{
public:
    // ctor (destructor, copy and move operations are implicitly generated)
    TrefoilKnot(float majorRadius=1.0f, float minorRadius=0.5f, float tubeRadius=0.2f, int n=3, int sectorCount=90, int sideCount=18, bool smooth=true, int up=3, int mode=0);

    // getters/setters
    float getMajorRadius() const            { return majorRadius; }
    float getMinorRadius() const            { return minorRadius; }
    float getTubeRadius() const             { return tubeRadius; }
    int getN() const                        { return n; }
    int getSectorCount() const              { return sectorCount; }
    int getSideCount() const                { return sideCount; }
    bool getSmooth() const                  { return smooth; }      // true=smooth, false=flat shading
    int getUpAxis() const                   { return upAxis; }
    int getMode() const                     { return mode; }
    void set(float majorRadius, float minorRadius, float tubeRadius, int n, int sectorCount, int sideCount, bool smooth=true, int up=3, int mode=0);
    void setMajorRadius(float radius);
    void setMinorRadius(float radius);
    void setTubeRadius(float radius);
    void setN(int n);
    void setSectorCount(int sectorCount);
    void setSideCount(int sideCount);
    void setSmooth(bool smooth);
    void setUpAxis(int up);
    void setMode(int mode);
    void reverseNormals();

    // for vertex data
    // counts are element counts (3 floats per vertex/normal, 2 per texcoord);
    // sizes are byte sizes suitable for glBufferData-style uploads
    unsigned int getVertexCount() const     { return (unsigned int)vertices.size() / 3; }
    unsigned int getNormalCount() const     { return (unsigned int)normals.size() / 3; }
    unsigned int getTexCoordCount() const   { return (unsigned int)texCoords.size() / 2; }
    unsigned int getIndexCount() const      { return (unsigned int)indices.size(); }
    unsigned int getLineIndexCount() const  { return (unsigned int)lineIndices.size(); }
    unsigned int getTriangleCount() const   { return getIndexCount() / 3; }
    unsigned int getVertexSize() const      { return (unsigned int)vertices.size() * sizeof(float); }
    unsigned int getNormalSize() const      { return (unsigned int)normals.size() * sizeof(float); }
    unsigned int getTexCoordSize() const    { return (unsigned int)texCoords.size() * sizeof(float); }
    unsigned int getIndexSize() const       { return (unsigned int)indices.size() * sizeof(unsigned int); }
    unsigned int getLineIndexSize() const   { return (unsigned int)lineIndices.size() * sizeof(unsigned int); }
    const float* getVertices() const        { return vertices.data(); }
    const float* getNormals() const         { return normals.data(); }
    const float* getTexCoords() const       { return texCoords.data(); }
    const unsigned int* getIndices() const  { return indices.data(); }
    const unsigned int* getLineIndices() const  { return lineIndices.data(); }

    // for interleaved vertices: V/N/T
    unsigned int getInterleavedVertexCount() const  { return getVertexCount(); }    // # of vertices
    unsigned int getInterleavedVertexSize() const   { return (unsigned int)interleavedVertices.size() * sizeof(float); }    // # of bytes
    int getInterleavedStride() const                { return interleavedStride; }   // should be 32 bytes
    const float* getInterleavedVertices() const     { return interleavedVertices.data(); }

    // draw in VertexArray mode
    void draw() const;                                  // draw surface
    void drawLines(const float lineColor[4]) const;     // draw lines only
    void drawWithLines(const float lineColor[4]) const; // draw surface and lines
    void drawPath(const float lineColor[4]) const;      // draw tube path line
    void drawFirstPoints(const float pointColor[4]) const;     // draw points only

    // debug
    void printSelf() const;

protected:

private:
    // nested struct to extrude contour along path
    struct vec3
    {
        float x;
        float y;
        float z;
        vec3(float x=0, float y=0, float z=0) : x(x), y(y), z(z) {}
        // scale to unit length in place; near-zero vectors are left unchanged
        vec3& normalize()
        {
            float length = sqrtf(x*x + y*y + z*z);
            if(length > 0.000001f)
            {
                float lenInv = 1.0f / length;
                x *= lenInv;
                y *= lenInv;
                z *= lenInv;
            }
            return *this;
        }
        float distance(const vec3& v) const { return sqrtf((v.x-x)*(v.x-x) + (v.y-y)*(v.y-y) + (v.z-z)*(v.z-z)); }
        vec3 cross(const vec3& v) const { return vec3(y*v.z - z*v.y, z*v.x - x*v.z, x*v.y - y*v.x); }
        float dot(const vec3& v) const { return x*v.x + y*v.y + z*v.z; }
        vec3 operator+(const vec3& v) const { return vec3(x+v.x, y+v.y, z+v.z); }
        vec3 operator-(const vec3& v) const { return vec3(x-v.x, y-v.y, z-v.z); }
        vec3 operator*(float s) const { return vec3(x*s, y*s, z*s); }
        vec3& operator+=(const vec3& v) { x += v.x; y += v.y; z += v.z; return *this; }
        vec3& operator-=(const vec3& v) { x -= v.x; y -= v.y; z -= v.z; return *this; }
        vec3& operator*=(float s) { x *= s; y *= s; z *= s; return *this; }
        friend std::ostream& operator<<(std::ostream& os, const vec3& v) { os << "(" << v.x << ", " << v.y << ", " << v.z << ")"; return os; }
    };

    // nested struct to store tmp vertex for flat shading
    struct Vertex
    {
        float x, y, z, s, t;
    };


    // member functions
    void buildPath();
    void buildContour();
    std::vector<vec3> projectContour(std::vector<vec3>& fromContour, int fromIndex, int toIndex);
    std::vector<vec3> transformContour(int fromIndex, int toIndex);
    void transformFirstContour();
    void buildVerticesSmooth();
    void buildVerticesFlat();
    void buildInterleavedVertices();
    void alignContoursSmooth();
    void alignContoursFlat(std::vector<Vertex>& vertices);
    void changeUpAxis(int from, int to);
    void clearArrays();
    void addVertex(float x, float y, float z);
    void addNormal(float x, float y, float z);
    void addTexCoord(float s, float t);
    void addIndices(unsigned int i1, unsigned int i2, unsigned int i3);
    vec3 computeFaceNormal(float x1, float y1, float z1,
                           float x2, float y2, float z2,
                           float x3, float y3, float z3);

    // member vars
    // In-class defaults mirror the constructor's default arguments so no
    // scalar is ever read uninitialized; the constructor's own values win.
    float majorRadius = 1.0f;
    float minorRadius = 0.5f;
    float tubeRadius = 0.2f;
    int n = 3;                              // # of leaves
    int sectorCount = 90;                   // # of sectors(rings)
    int sideCount = 18;                     // # of sides
    bool smooth = true;
    int upAxis = 3;                         // +X=1, +Y=2, +Z=3 (default)
    std::vector<float> vertices;
    std::vector<float> normals;
    std::vector<float> texCoords;
    std::vector<unsigned int> indices;
    std::vector<unsigned int> lineIndices;

    // interleaved
    std::vector<float> interleavedVertices;
    int interleavedStride = 32;             // # of bytes to hop to the next vertex: 8 floats (V3/N3/T2)

    // for tube path
    std::vector<vec3> path;
    std::vector<unsigned int> pathIndices;
    std::vector<vec3> pathDirections;       // path direction vectors
    std::vector<vec3> contour;              // contour at the first path point
    std::vector<vec3> contourNormal;        // normal at the first path point
    std::vector<vec3> firstVectors;         // direction of first vertex of each contour
    int mode = 0;                           // tube generation mode: 0=project, 1=transform
};

#endif
