/* Copyright 2024, Alejandro A. García <aag@zorzal.net>
 * SPDX-License-Identifier: Zlib
 */
#include "rng_mt19937.h"
#include <math.h>

//ref: https://en.wikipedia.org/wiki/Mersenne_Twister#C_code

// MT19937 parameters, kept as single-letter macros to mirror the reference
// pseudo-code. Because these are macros, none of these letters can be used as
// an ordinary identifier while the corresponding macro is still defined.

// Number of 32-bit words in the state array (loop bound and circular-index
// modulus in the functions below).
#define n 624
// Offset of the second state word combined in the recurrence.
#define m 397
// Word size in bits.
#define w 32
// Bit position at which a word is split into its upper and lower parts.
#define r 31
// Mask for the bits at position r and above. If unsigned long is wider than
// 32 bits the extra high bits are harmless: the mask is only ANDed with
// uint32_t state words.
#define UMASK (0xffffffffUL << r)
// Mask for the low r bits of a word.
#define LMASK (0xffffffffUL >> (w-r))
// Constant XORed into the shifted word when its lowest bit is set.
#define a 0x9908b0dfUL
// Tempering: first right-shift amount.
#define u 11
// Tempering: first left-shift amount (paired with mask b).
#define s 7
// Tempering: second left-shift amount (paired with mask c).
#define t 15
// Tempering: final right-shift amount.
#define l 18
// Tempering mask applied after the shift by s.
#define b 0x9d2c5680UL
// Tempering mask applied after the shift by t.
#define c 0xefc60000UL
// Multiplier of the seeding recurrence used by mt19937_init.
#define f 1812433253UL

// Global generator state defined by this translation unit. Having static
// storage duration and no initializer, it starts out zero-initialized.
// NOTE(review): nothing in this file calls mt19937_init() on it, so users
// presumably have to seed it themselves before drawing numbers -- confirm.
MT19937State g_rng_mt19937;

/* Seeds the generator state S from a 32-bit seed.
 *
 * Word 0 of the state array is the seed itself; every following word is
 * derived from its predecessor with the multiplicative recurrence below.
 * The read position is reset to the start of the array and the cached
 * normal deviate is cleared (NaN is the "nothing cached" marker tested by
 * mt19937_randn_mp).
 */
void mt19937_init(MT19937State* S, uint32_t seed)
{
    uint32_t prev = seed;
    unsigned pos = 0;

    S->array[pos] = prev;

    while (++pos < n) {
        // Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier.
        prev ^= prev >> (w-2);
        prev = f * prev + pos;   // truncated back to 32 bits on assignment
        S->array[pos] = prev;
    }

    S->saved = NAN;
    S->index = 0;
}

/* Advances the generator by one step and returns the next 32-bit output.
 *
 * The state array is treated as a circular buffer: the word at the current
 * index is replaced by a new word computed from itself, its successor and
 * the word m positions ahead (all indices modulo n). The new word is then
 * tempered to produce the returned value.
 */
uint32_t mt19937_uint32(MT19937State* S)
{
    const int cur = S->index;   // slot that is regenerated in this call
    assert( 0 <= cur && cur < n );

    // Circular neighbours of the current slot.
    const int next = (cur + 1) % n;
    const int far  = (cur + m) % n;

    // Upper part of the current word joined with the lower part of the next.
    const uint32_t mixed = (S->array[cur] & UMASK) | (S->array[next] & LMASK);

    // Shift right by one; fold in the constant a when the dropped bit was set.
    const uint32_t twist = (mixed >> 1) ^ ((mixed & 1u) ? a : 0UL);

    // New state word; it overwrites the slot just consumed.
    const uint32_t fresh = S->array[far] ^ twist;
    S->array[cur] = fresh;
    S->index = next;

    // Tempering.
    uint32_t out = fresh;
    out ^= out >> u;
    out ^= (out << s) & b;
    out ^= (out << t) & c;
    out ^= out >> l;

    return out;
}

// Scale factors that map a 32-bit integer onto the unit interval and onto
// one full turn in radians, respectively.
// NOTE(review): these have external linkage; whether other translation
// units or the header rely on them is not visible from this file.
const double two_pow32_inv     = 2.3283064365386963e-10; //   1/2^32
const double two_pow32_inv_2pi = 1.4629180792671596e-09; // 2pi/2^32

// The generator parameters u and m are not needed past this point; drop the
// macros so that these letters can be used as ordinary identifiers again.
#undef u
#undef m

/* Box-Muller transform: turns two 32-bit integers into one normal deviate.
 *
 * x selects the radius and y the angle. Both are offset by 0.5 before
 * scaling, so the value passed to log() is strictly between 0 and 1 and
 * the logarithm is always finite. Only the sine branch of the transform is
 * produced.
 */
static inline
double rng_box_muller(uint32_t x, uint32_t y)
{
	const double unit   = (x + 0.5) * two_pow32_inv;      // in (0, 1)
	const double angle  = (y + 0.5) * two_pow32_inv_2pi;  // in (0, 2*pi)
	const double radius = sqrt(-2.0 * log(unit));
	return radius * sin(angle);
}

// Returns a normally-distributed random number obtained with the
// Box-Muller transform. Consumes exactly two 32-bit outputs of the generator
// and does not touch the cached deviate kept in S->saved.
double mt19937_randn_bm(MT19937State* S)
{
	// Two separate statements keep the draw order fixed:
	// the first output feeds the radius, the second the angle.
	const uint32_t radius_bits = mt19937_uint32(S);
	const uint32_t angle_bits  = mt19937_uint32(S);

	return rng_box_muller(radius_bits, angle_bits);
}

// Returns a normally-distributed random number using the Marsaglia polar
// method. Each accepted point yields two deviates: one is returned at once,
// the other is stored in S->saved and handed out by the next call.
// S->saved holds NaN whenever no deviate is cached.
double mt19937_randn_mp(MT19937State* S)
{
	// Serve the deviate left over from the previous call, if there is one.
	const double cached = S->saved;
	if (!isnan(cached)) {
		S->saved = NAN;
		return cached;
	}

	// Marsaglia polar method used in glibc as of 2024-09-18
	// (attribution taken from the original author's note).
	// mt19937_rand() is defined outside this file; presumably uniform on
	// [0, 1), which puts each coordinate in [-1, 1) -- confirm.
	double px, py, norm2;
	for (;;) {
		px = mt19937_rand(S) * 2.0 - 1.0;
		py = mt19937_rand(S) * 2.0 - 1.0;
		norm2 = px * px + py * py;
		// Accept points that are not outside the unit circle and not
		// at the origin.
		if (!(norm2 > 1.0) && norm2 != 0.0) {
			break;
		}
	}

	const double scale = sqrt(-2.0 * log(norm2) / norm2);
	S->saved = px * scale;
	return py * scale;
}
