/* Copyright 2026, Alejandro A. García <aag@zorzal.net>
 * SPDX-License-Identifier: Zlib
 */
#include "sound.h"
#include "ccommon.h"
#include "alloc.h"
#include <assert.h>
#include <math.h>
#include <string.h>

#ifndef SOUND_ALLOCATOR
#define SOUND_ALLOCATOR  g_allocator
#endif

// Assumes LSB (little endian) system

// Attribute table with one entry per sound format (SND_FMT__COUNT entries).
// NOTE(review): the initializers are positional and the field order is
// declared in sound.h — presumably { sample size in bytes, short name,
// format code, silence value }; confirm against SndFormatAttr.
// Entries that omit the last initializer leave that field zero-initialized.
const SndFormatAttr g_snd_format_attrs[SND_FMT__COUNT] = {
	{ 0, "null", 0 },
	{ 4, "f32",  0x8120 },
	{ 2, "s16",  0x8010 },
	{ 1, "u8",   0x0008, 128 },
};

// Converts a float sample to signed 16-bit.
// Values at or beyond +/-1.0 saturate to 32767 / -32768; everything in between
// is scaled by 32767 and truncated toward zero.
static inline
int16_t f32sys_to_s16sys(float v) {
	if (v >= 1.0f)
		return 32767;
	if (v <= -1.0f)
		return -32768;
	return (int16_t)(v * 32767.0f);
}

// Converts a float sample to unsigned 8-bit.
// Values at or beyond +/-1.0 saturate to 255 / 0; everything in between is
// shifted to [0, 2), scaled by 127 and truncated.
static inline
uint8_t f32sys_to_u8(float v) {
	if (v >= 1.0f)
		return 255;
	if (v <= -1.0f)
		return 0;
	return (uint8_t)((v + 1.0f) * 127.0f);
}

// Converts a signed 16-bit sample to float in [-1.0, 1.0) by dividing by 32768.
static inline
float s16sys_to_f32sys(int16_t v) {
	const float scale = 1.0f / 32768.0f;
	const float sample = (float)v;
	return sample * scale;
}

// Converts an unsigned 8-bit sample to float in [-1.0, 1.0): 128 maps to 0.0.
static inline
float u8_to_f32sys(uint8_t v) {
	const float scale = 1.0f / 128.0f;
	const float sample = (float)v;
	return sample * scale - 1.0f;
}

// Releases the sample buffer, but only when this Sound owns it, and then
// resets the data pointer, the length and the flags. The remaining fields
// (channels, stride, frequency, format) are left as they were.
void snd_free(Sound* S)
{
	const int owns_buffer = (S->flags & SND_F_OWN_MEM) != 0;
	if (owns_buffer && S->data != NULL)
		alloc_free(SOUND_ALLOCATOR, S->data);

	S->flags = 0;
	S->len = 0;
	S->data = NULL;
}

// (Re)allocates the buffer of S for `len` samples of `ch` channels in format
// `fmt`, and records the new geometry. A zero `ch`, `freq` or `fmt` keeps the
// value currently stored in S.
// A zero length, or a format whose sample size is zero, frees S instead.
// After a successful call S owns its buffer (flags == SND_F_OWN_MEM) and is
// contiguous (ss == sample size).
void snd_resize(Sound* S, unsigned len, unsigned ch, unsigned freq, SndFormat fmt)
{
	// Zero means "inherit from S"
	ch = ch ? ch : S->ch;
	freq = freq ? freq : S->freq;
	fmt = fmt ? fmt : S->format;

	const SndFormatAttr * attr = snd_format_attr(fmt);
	if (attr->size == 0 || len == 0) {
		snd_free(S);
		return;
	}

	// A borrowed buffer (a view) must never be handed to the allocator:
	// forget it so that a fresh block is allocated.
	const int owned = (S->flags & SND_F_OWN_MEM) != 0;
	if (!owned)
		S->data = NULL;

	const unsigned total_bytes = len * ch * attr->size;
	S->data = alloc_realloc(SOUND_ALLOCATOR, S->data, total_bytes);

	S->flags = SND_F_OWN_MEM;
	S->format = fmt;
	S->freq = freq;
	S->ss = attr->size;
	S->ch = ch;
	S->len = len;
}

// Fills every sample of S with the silence value of its format.
// Works on contiguous buffers (single memset) and on strided views (one
// memset per element, advancing by the stride S->ss).
// The silence value is applied byte-wise, as memset does.
void snd_zero(Sound* S)
{
	const SndFormatAttr * fa = snd_format_attr(S->format);
	const int silence = fa->silence;

	// Nothing to fill; also avoids handing a null pointer to memset
	if (!S->data) return;

	if (S->ss == (int)fa->size) {  //Contiguous
		const size_t n_byte = (size_t)S->len * S->ch * fa->size;
		memset(S->data, silence, n_byte);
	} else {
		// Strided: walk a local copy so that S itself is not modified.
		// The fill value must be the silence byte, not a loop counter.
		Sound snd = *S;
		const size_t sz = fa->size;
		for (unsigned i=0; i<snd.len; ++i) {
			for (unsigned c=0; c<snd.ch; ++c) {
				memset(snd.data, silence, sz);
				snd.data += snd.ss;
			}
		}
	}
}

// Copies the samples of `src` into `dst`, resizing `dst` to the same length,
// channel count, frequency and format. The destination is always contiguous,
// even when the source is a strided view.
// Copying a sound onto itself is a no-op.
void snd_copy(Sound* dst, const Sound* src)
{
	if (dst == src) return;

	snd_resize(dst, src->len, src->ch, src->freq, src->format);
	// snd_resize leaves no buffer for an empty sound: nothing to copy
	if (!dst->data) return;

	const SndFormatAttr * fa = snd_format_attr(src->format);
	if (src->ss == (int)fa->size) {  //Contiguous
		const size_t n_byte = (size_t)src->len * src->ch * fa->size;
		memcpy(dst->data, src->data, n_byte);
	} else {
		// Strided source: gather element by element. The read cursor has
		// to advance together with the write cursor.
		const uint8_t * sdata = src->data;
		uint8_t * ddata = dst->data;
		const size_t sz = fa->size;
		for (unsigned i=0; i<src->len; ++i) {
			for (unsigned c=0; c<src->ch; ++c) {
				memcpy(ddata, sdata, sz);
				ddata += sz;
				sdata += src->ss;
			}
		}
	}
}

// Converts the samples of `src` to format `dfmt` and stores them in `dst`,
// which is resized to the same length, channel count and frequency.
// Every sample goes through an intermediate float value. A source with
// format SND_FMT_NONE produces float 0 for every element.
// `dst` is resized before `src` is read, so the two must not share a buffer.
void snd_format_convert(Sound* dst, const Sound* src, SndFormat dfmt)
{
	const SndFormat sfmt = src->format;
	const int sss = src->ss;  //Signed: views may have a negative stride

	snd_resize(dst, src->len, src->ch, src->freq, dfmt);
	const uint8_t * sdata = src->data;
	uint8_t * ddata = dst->data;
	// snd_resize leaves no buffer for an empty sound or a null format
	if (!ddata) return;

	const int dss = dst->ss;
	const unsigned len = dst->len;
	const unsigned ch = dst->ch;

	float v=0;
	for (unsigned i=0; i<len; ++i) {
		// Exactly `ch` elements per frame: one more would run past the
		// end of both buffers.
		for (unsigned c=0; c<ch; ++c) {
			switch (sfmt) {
			case SND_FMT_NONE:
				break;
			case SND_FMT_F32:
				v = *(float*)sdata;
				break;
			case SND_FMT_S16:
				v = s16sys_to_f32sys(*(int16_t*)sdata);
				break;
			case SND_FMT_U8:
				v = u8_to_f32sys(*(uint8_t*)sdata);
				break;
			}
			switch (dfmt) {
			case SND_FMT_NONE:
				break;
			case SND_FMT_F32:
				*(float*)ddata = v;
				break;
			case SND_FMT_S16:
				*(int16_t*)ddata = f32sys_to_s16sys(v);
				break;
			case SND_FMT_U8:
				*(uint8_t*)ddata = f32sys_to_u8(v);
				break;
			}
			ddata += dss;
			sdata += sss;
		}
	}
	//Probably, not the fastest implementation...
}

/* Clips the range (start I, length L, step S) to the valid indices [0, N).
 * I and L must be modifiable int lvalues; they are updated in place.
 * Jumps to a label named `empty`, which the enclosing function must provide,
 * when nothing is left of the range.
 *
 * S > 0: a negative start is moved forward, in whole steps, to the first
 *        non-negative index; the length is cut so that I + L*S <= N.
 * S < 0: a start at or past N is moved to the largest multiple of |S| that
 *        is still below N; the length is cut so that I + L*S >= 0.
 * S = 0: the start only has to be inside [0, N).
 *
 * The length is not reduced by the steps skipped when the start is moved.
 */
#define VIEW_RANGE_FIX(N, I, L, S) \
	if (L <= 0 || N <= 0) goto empty; \
	if (S > 0) { \
		if (I < 0) I -= ((I - S + 1) / S) * S; \
		else if (I >= N) goto empty; \
		if (I + L * S > N) { \
			L = (N - I) / S; \
			if (L <= 0) goto empty; \
		} \
	} else if (S < 0) { \
		if (I < 0) goto empty; \
		else if (I >= N) I = ((N - 1) / S) * S; \
		if (I + L * S < 0) { \
			L = -I / S; \
			if (L <= 0) goto empty; \
		} \
	} else { \
		if (I < 0 || I >= N) goto empty; \
	} \
	assert( 0 <= I && I < N ); \
	assert( 0 <= (I+L*S) && (I+L*S) <= N );

// Makes `dst` a view over a sub-range of `src`: it points into the memory of
// `src` and is created with flags == 0, i.e. without SND_F_OWN_MEM.
// The sample range is (s0, sl, ss) and the channel range is (c0, cl, cs),
// each given as start, length and step. Both ranges are clipped by
// VIEW_RANGE_FIX.
// Only two channel selections are accepted: all channels, or a single channel
// with a sample step of 1. Anything else, or an empty range, leaves `dst`
// zeroed.
void snd_view_make(Sound* dst, const Sound* src,
	int s0, int sl, int ss,		//Samples start, length, step
	int c0, int cl, int cs)		//Channels start, length, step
{
	// Drop any buffer dst currently owns before turning it into a view.
	// NOTE(review): if dst == src and it owns its buffer, this frees the
	// source before it is read below — confirm callers never do that.
	if (dst->flags & SND_F_OWN_MEM) snd_free(dst);
	
	// Clip both ranges; either macro may jump to `empty`
	int len = src->len, ch = src->ch;
	VIEW_RANGE_FIX(len, s0, sl, ss);
	VIEW_RANGE_FIX(ch , c0, cl, cs);
	
	// There is only one stride in Sound and not all combinations are possible.
	if (cs != 1) goto empty;
	if (cl != 1 && cl != ch) goto empty;
	if (c0 != 0 || cl != ch) {
		// Single-channel view: consecutive elements are one whole frame
		// (ch elements) apart in the source.
		if (ss != 1 || cl != 1) goto empty;
		cs = ch;
	}

	// SND_INDEX is defined elsewhere; presumably it yields the address of
	// sample s0, channel c0 inside src — confirm in sound.h.
	dst->data = SND_INDEX(*src, s0, c0);
	dst->len = sl;
	dst->ch = cl;
	dst->ss = src->ss * ss * cs;
	dst->freq = src->freq;  //TODO: adjust if |ss|>1 ?
	dst->format = src->format;
	dst->flags = 0;  //A view does not own the memory
	return;

empty:
	*dst = (Sound){0};
}

/* Code for generic point-wise operations */

/* In-place point-wise operation over every element of a sound.
 * SND is a Sound (by value): the macro works on a local copy, so only the
 * sample memory is modified, never the Sound structure itself.
 * EXPR is a float expression evaluated once per element. It may use:
 *   v - the current sample converted to float
 *   i - the sample index (unsigned)
 *   c - the channel index (unsigned)
 * The result is converted back to the format of the sound and stored.
 * Elements are visited sample by sample, channel by channel, advancing the
 * data pointer by the stride `ss` after each one.
 */
#define SND_OP1(SND, EXPR) do { \
	Sound snd_ = (SND); \
	switch (snd_.format) { \
	case SND_FMT_NONE: \
		break; \
	case SND_FMT_F32: { \
		for (unsigned i=0; i<snd_.len; ++i) { \
			for (unsigned c=0; c<snd_.ch; ++c) { \
				float v = *(float*) snd_.data; \
				v = (EXPR); \
				*(float*)snd_.data = v; \
				snd_.data += snd_.ss; \
			} \
		} \
		break; \
	} \
	case SND_FMT_S16: { \
		for (unsigned i=0; i<snd_.len; ++i) { \
			for (unsigned c=0; c<snd_.ch; ++c) { \
				float v = s16sys_to_f32sys(*(int16_t*)snd_.data); \
				v = (EXPR); \
				*(int16_t*)snd_.data = f32sys_to_s16sys(v); \
				snd_.data += snd_.ss; \
			} \
		} \
		break; \
	} \
	case SND_FMT_U8: { \
		for (unsigned i=0; i<snd_.len; ++i) { \
			for (unsigned c=0; c<snd_.ch; ++c) { \
				float v = u8_to_f32sys(*(uint8_t*)snd_.data); \
				v = (EXPR); \
				*(uint8_t*)snd_.data = f32sys_to_u8(v); \
				snd_.data += snd_.ss; \
			} \
		} \
		break; \
	} \
	} \
} while(0)

/* Point-wise operation with two inputs: D[k] = EXPR(A[k], B[k]).
 * D, A and B are Sounds (by value); the macro works on local copies and only
 * writes the sample memory of D. D may be the same sound as A or B.
 * EXPR is a float expression evaluated once per element. It may use:
 *   a, b - the current samples of A and B converted to float
 *   i    - the sample index (unsigned)
 *   c    - the channel index (unsigned)
 * The loop bounds and the sample format are taken from D alone: A and B are
 * read with the format of D and are not checked, so the caller must make sure
 * all three have the same format and at least as many elements as D.
 * Each sound advances by its own stride `ss`.
 */
#define SND_OP2(D, A, B, EXPR) do { \
	Sound snd_ = (D); \
	Sound A_ = (A); \
	Sound B_ = (B); \
	switch (snd_.format) { \
	case SND_FMT_NONE: \
		break; \
	case SND_FMT_F32: { \
		for (unsigned i=0; i<snd_.len; ++i) { \
			for (unsigned c=0; c<snd_.ch; ++c) { \
				const float a = *(float*)A_.data; \
				const float b = *(float*)B_.data; \
				const float v = (EXPR); \
				*(float*)snd_.data = v; \
				snd_.data += snd_.ss; \
				A_.data += A_.ss; \
				B_.data += B_.ss; \
			} \
		} \
		break; \
	} \
	case SND_FMT_S16: { \
		for (unsigned i=0; i<snd_.len; ++i) { \
			for (unsigned c=0; c<snd_.ch; ++c) { \
				const float a = s16sys_to_f32sys(*(int16_t*)A_.data); \
				const float b = s16sys_to_f32sys(*(int16_t*)B_.data); \
				const float v = (EXPR); \
				*(int16_t*)snd_.data = f32sys_to_s16sys(v); \
				snd_.data += snd_.ss; \
				A_.data += A_.ss; \
				B_.data += B_.ss; \
			} \
		} \
		break; \
	} \
	case SND_FMT_U8: { \
		for (unsigned i=0; i<snd_.len; ++i) { \
			for (unsigned c=0; c<snd_.ch; ++c) { \
				const float a = u8_to_f32sys(*(uint8_t*)A_.data); \
				const float b = u8_to_f32sys(*(uint8_t*)B_.data); \
				const float v = (EXPR); \
				*(uint8_t*)snd_.data = f32sys_to_u8(v); \
				snd_.data += snd_.ss; \
				A_.data += A_.ss; \
				B_.data += B_.ss; \
			} \
		} \
		break; \
	} \
	} \
} while(0)

// Square wave with period 1: -1 on the first half of each period and +1 on
// the second half, taken from the sign of (frac(x) * 2 - 1).
static inline
float square_wave(float x)
{
	float whole;
	const float frac = modff(x, &whole);
	const float centered = frac * 2.0f - 1.0f;
	return copysignf(1.0, centered);
}

// Triangle wave with period 1: +1 at frac(x) == 0, falling to -1 at 0.5 and
// rising back afterwards.
static inline
float triangle_wave(float x)
{
	float whole;
	const float frac = modff(x, &whole);
	const float folded = fabsf(frac * 4.0f - 2.0f);
	return folded - 1.0f;
}

// Sawtooth wave with period 1: ramps linearly from -1 at frac(x) == 0 up
// towards +1, crossing zero at 0.5.
static inline
float sawtooth_wave(float x)
{
	float whole;
	const float frac = modff(x, &whole);
	return frac * 2.0f - 1.0f;
}

// Adds a periodic wave to every element of S, in place.
//   type   waveform to generate; SND_WT_NONE leaves S untouched
//   freq   wave frequency, in the same unit as S->freq
//   amp    amplitude multiplier
//   phase  initial phase, in periods (1.0 is a whole period)
// The wave is evaluated from the sample index, so all the channels of a
// sample receive the same value.
void snd_wave_add(Sound* S, SndWaveType type, double freq, double amp,
	double phase)
{
	const float gain = amp;

	switch (type) {
	case SND_WT_NONE:
		break;

	// Sine based waves: phase and step are expressed in radians
	case SND_WT_SINE:
	case SND_WT_SINE_PM_SINE: {
		const float ph0 = phase * 2.0 * M_PI;
		const float step = 2.0 * M_PI * freq / S->freq;
		if (type == SND_WT_SINE) {
			SND_OP1(*S, v + sinf(ph0 + step * i) * gain);
		} else {
			// The sine modulates its own phase
			SND_OP1(*S, v + sinf(ph0 + step * i + sinf(ph0 + step * i)) * gain);
		}
		break;
	}

	// Unit period waves: phase and step are expressed in periods.
	// The half period offset makes the sawtooth start at zero and the
	// square start at its rising edge.
	// NOTE(review): triangle_wave(0.5) is -1, so with this offset the
	// triangle starts at its minimum, not at zero — confirm the intent.
	case SND_WT_SQUARE:
	case SND_WT_TRIANGLE:
	case SND_WT_SAWTOOTH: {
		const float ph0 = phase + 0.5;
		const float step = freq / S->freq;
		if (type == SND_WT_SQUARE) {
			SND_OP1(*S, v + square_wave(ph0 + step * i) * gain);
		} else if (type == SND_WT_TRIANGLE) {
			SND_OP1(*S, v + triangle_wave(ph0 + step * i) * gain);
		} else {
			SND_OP1(*S, v + sawtooth_wave(ph0 + step * i) * gain);
		}
		break;
	}
	}
}

// Multiplies S, in place, by a gain curve that starts at m1 on the first
// sample and moves towards m2 along the sound.
//   SND_CT_LINEAR  gain = m1 + (m2 - m1) * i / len
//   SND_CT_EXP     gain = m2 + (m1 - m2) * exp(-sigma * i / len)
// `sigma` is only used by the exponential curve; IFNPOSSET (from ccommon.h,
// presumably "if not positive, set") gives it the default value 5.
// Any other curve type leaves S untouched.
void snd_fade(Sound* S, SndCurveType type, double m1, double m2, double sigma)
{
	if (type == SND_CT_LINEAR) {
		const float slope = (m2 - m1) / S->len;
		const float start = m1;
		SND_OP1(*S, v * (start + slope * i) );
	}
	else if (type == SND_CT_EXP) {
		IFNPOSSET(sigma, 5);
		const float target = m2;
		const float span = m1 - m2;
		const float rate = -sigma / S->len;
		SND_OP1(*S, v * (target + span * expf(rate * i)) );
	}
}

// Applies an attack / decay / sustain / release envelope to S, in place.
// The three durations are in seconds and are rounded to whole samples using
// S->freq; the release stage takes whatever is left of the sound.
// The gain goes 0 -> m_attack -> m_decay -> m_sustain -> 0, each stage being
// a snd_fade of type `ctype` over its own slice. Stages without samples are
// skipped.
void snd_adsr(Sound* S, double t_attack, double t_decay, double t_sustain,
	double m_attack, double m_decay, double m_sustain, SndCurveType ctype)
{
	const int n_attack  = t_attack  * S->freq + 0.5;
	const int n_decay   = t_decay   * S->freq + 0.5;
	const int n_sustain = t_sustain * S->freq + 0.5;
	const int n_release = (int)S->len - n_sustain - n_decay - n_attack;

	// One row per stage: sample count, start gain, end gain
	const struct { int n; double from, to; } stages[] = {
		{ n_attack,  0,         m_attack  },
		{ n_decay,   m_attack,  m_decay   },
		{ n_sustain, m_decay,   m_sustain },
		{ n_release, m_sustain, 0         },
	};

	int pos = 0;
	for (unsigned k=0; k<sizeof(stages)/sizeof(stages[0]); ++k) {
		const int n = stages[k].n;
		if (n <= 0) continue;
		Sound vw = snd_slice_get(S, pos, pos+n);
		snd_fade(&vw, ctype, stages[k].from, stages[k].to, 5);
		pos += n;
	}
}

// Mixes `src` into `dst`: dst = dst * dmul + src * smul, element by element.
// Does nothing unless both sounds have equal specs (per snd_specs_equal) and
// the same length.
void snd_add(Sound* dst, const Sound* src, double dmul, double smul)
{
	const int compatible = snd_specs_equal(dst, src) && dst->len == src->len;
	if (!compatible)
		return;

	const float dst_gain = dmul;
	const float src_gain = smul;
	SND_OP2(*dst, *dst, *src, a * dst_gain + b * src_gain);
}

/* Random generation */

// *Really* minimal PCG32 code / (c) 2014 M.E. O'Neill / pcg-random.org
// Licensed under Apache License 2.0 (NO WARRANTY, etc. see website)
// Generator state: `state` is advanced on every call, `inc` selects the
// increment added on each step (its lowest bit is forced to 1 when used).
typedef struct { uint64_t state;  uint64_t inc; } pcg32_random_t;
// Fixed seed: every run produces the same sequence
#define PCG32_INITIALIZER   { 0x853c49e6748fea9bULL, 0xda3e39cb94b95bdbULL }

// Returns the next 32-bit value of the sequence and advances `rng`.
static inline
uint32_t pcg32_random_r(pcg32_random_t* rng)
{
    uint64_t oldstate = rng->state;
    // Advance internal state
    rng->state = oldstate * 6364136223846793005ULL + (rng->inc|1);
    // Calculate output function (XSH RR), uses old state for max ILP
    uint32_t xorshifted = ((oldstate >> 18u) ^ oldstate) >> 27u;
    uint32_t rot = oldstate >> 59u;
    // Rotate right by `rot`; the mask keeps the left shift below 32
    return (xorshifted >> rot) | (xorshifted << ((-rot) & 31));
}
// End of code block

// Returns a pseudo-random float in [0, 1].
// Draws from a single function-local PCG32 generator with a fixed seed, so
// the sequence is the same on every run.
static inline
float rand_f32(void) {
	static pcg32_random_t generator = PCG32_INITIALIZER;
	const float scale = 1.0f / 0xffffffff;
	const uint32_t bits = pcg32_random_r(&generator);
	return (float)bits * scale;  //[0, 1]
}

// Weighted Stochastic Voss-McCartney Algorithm
// from https://www.ridgerat-tech.us/pink/pinkalg.htm
// (c) Larry Trammell, 2016-2020 
// Creative Commons Attribution 4.0 International License
// Per-stage amplitude (av) and update threshold (pv)
static const float pink_gen_av[3] = {0.24390449, 0.31582856, 0.44026694};
static const float pink_gen_pv[3] = {0.31878, 0.77686, 0.97785};
// Last value produced by each stage
typedef struct { float v[3]; } pink_gen_t;

// Produces the next pink noise sample and updates `state`.
// A single random draw decides which stages are refreshed: a stage gets a new
// value, uniform in [-av, +av], when the draw exceeds its threshold. The
// output is the sum of all the stage values.
static inline
float pink_gen(pink_gen_t* state) {
	float sum = 0;
	const float draw = rand_f32();
	for (unsigned k=0; k<COUNTOF(state->v); ++k) {
		float * stage = &state->v[k];
		if (draw > pink_gen_pv[k])
			*stage = pink_gen_av[k] * (rand_f32() * 2.0f - 1.0f);
		sum += *stage;
	}
	return sum;
}
// End of code block

// Fixed starting values for the three stages of the pink noise generator
#define PINK_GEN_INITIALIZER  {{0.2031007f, -0.07548388f, 0.2611043f}}

// Returns the next pink noise sample from a single function-local generator.
static inline
float pink_f32(void) {
	static pink_gen_t generator = PINK_GEN_INITIALIZER;
	const float sample = pink_gen(&generator);
	return sample;
}

// Adds noise, scaled by `amp`, to every element of S, in place.
//   SND_NT_WHITE  uniform values in [-1, 1]
//   SND_NT_PINK   values from the pink noise generator
// Any other type leaves S untouched.
void snd_noise_add(Sound* S, SndNoiseType type, double amp)
{
	const float gain = amp;

	if (type == SND_NT_WHITE) {
		SND_OP1(*S, v + (rand_f32() * 2.0f - 1.0f) * gain);
	}
	else if (type == SND_NT_PINK) {
		SND_OP1(*S, v + pink_f32() * gain);
	}
}

/* Utility */

/* Helpers to build a binary header. Each one appends a field at `cur`, a
 * uint8_t* that must exist in the scope where the macro is used, and then
 * advances `cur` past it.
 *   HEADER_PUSH_STR  the characters of a string literal, without the NUL
 *   HEADER_PUSH_U32  a 32-bit unsigned integer
 *   HEADER_PUSH_U16  a 16-bit unsigned integer
 * Integers are stored byte by byte, least significant byte first, which is
 * the order RIFF/WAVE uses. This does not depend on the alignment of `cur`
 * nor on the byte order of the host.
 */
#define HEADER_PUSH_STR(S) do { \
	memcpy(cur, S, sizeof(S)-1); \
	cur += sizeof(S)-1; \
} while(0)
#define HEADER_PUSH_U32(N) do { \
	const uint32_t n_ = (uint32_t)(N); \
	cur[0] = (uint8_t)(n_ & 0xff); \
	cur[1] = (uint8_t)((n_ >> 8) & 0xff); \
	cur[2] = (uint8_t)((n_ >> 16) & 0xff); \
	cur[3] = (uint8_t)((n_ >> 24) & 0xff); \
	cur += 4; \
} while(0)
#define HEADER_PUSH_U16(N) do { \
	const uint16_t n_ = (uint16_t)(N); \
	cur[0] = (uint8_t)(n_ & 0xff); \
	cur[1] = (uint8_t)((n_ >> 8) & 0xff); \
	cur += 2; \
} while(0)

// Writes a canonical WAV (RIFF/WAVE) header describing S into `header`,
// which must hold SND_WAV_HEADER_SIZE bytes. Only the header is produced;
// the sample data is expected to follow it, tightly packed.
// The format tag is 3 (IEEE float) for SND_FMT_F32 and 1 (PCM) otherwise.
void snd_wav_header_fill(const Sound* S, uint8_t header[SND_WAV_HEADER_SIZE])
{
	const SndFormatAttr * fa = snd_format_attr(S->format);
	unsigned n_byte = S->len * S->ch * fa->size;  //Size of the sample data

	uint8_t * cur = header;

	// RIFF chunk. Its size field counts everything that follows the field,
	// that is, the whole file minus the 8 bytes of "RIFF" and the size.
	HEADER_PUSH_STR("RIFF");
	HEADER_PUSH_U32(SND_WAV_HEADER_SIZE - 8 + n_byte);
	HEADER_PUSH_STR("WAVE");

	// Format chunk, 16 bytes of payload
	HEADER_PUSH_STR("fmt ");
	HEADER_PUSH_U32(16);
	HEADER_PUSH_U16( S->format == SND_FMT_F32 ? 3 : 1 );
	HEADER_PUSH_U16( S->ch );
	HEADER_PUSH_U32(S->freq);
	HEADER_PUSH_U32(S->freq * S->ch * fa->size);  //bytes per sec
	HEADER_PUSH_U16(S->ch * fa->size);  //bytes per block
	HEADER_PUSH_U16(fa->size * 8);  //bits per sample

	// Data chunk header; the samples come right after it
	HEADER_PUSH_STR("data");
	HEADER_PUSH_U32(n_byte);
	assert(cur == header + SND_WAV_HEADER_SIZE);
}
