1 module wafom.precomputation;
2 
3 import digitalnet.implementation;
4 import std.algorithm, std.math, std.functional;
5 public import std.math : lg = log2;
6 
/**
 * Evaluates the WAFOM-style figure of merit of the digital net `P`.
 *
 * `P` is combined (via `+`) with a default-initialised shift vector of
 * `P.dimensionR` elements so that it can be handed to `WAFOMimpl`
 * (presumably this yields a `ShiftedDigitalNet` with a trivial shift --
 * confirm against `digitalnet.implementation`).
 *
 * Params:
 *   P        = the digital net to evaluate
 *   c        = offset used in the per-bit weight `2^(exponent * (c - 1 - position))`
 *   exponent = multiplier applied to the per-bit exponent
 *
 * Returns: the value computed by `WAFOMimpl` with the memoized calculator
 *          for `(P.precision, c, exponent)`.
 */
real WAFOM(U, Size)(DigitalNet!(U, Size) P, real c = 1, real exponent = 1)
{
	auto shifted = P + new U[P.dimensionR];
	auto calculator = getWAFOMCalculator!U(P.precision, c, exponent);
	return WAFOMimpl(shifted, calculator);
}
11 
/**
 * Same evaluation as `WAFOM`, but with the exponent fixed to 2.
 *
 * NOTE(review): judging by the name this is the root-mean-square variant of
 * WAFOM; the code itself only differs from `WAFOM` in passing `2` as the
 * exponent and applies no square root.
 *
 * Params:
 *   P = the digital net to evaluate
 *   c = offset used in the per-bit weight `2^(2 * (c - 1 - position))`
 *
 * Returns: the value computed by `WAFOMimpl` with the memoized calculator
 *          for `(P.precision, c, 2)`.
 */
real RMSWAFOM(U, Size)(DigitalNet!(U, Size) P, real c = 1)
{
	auto shifted = P + new U[P.dimensionR];
	auto calculator = getWAFOMCalculator!U(P.precision, c, 2);
	return WAFOMimpl(shifted, calculator);
}
16 
/**
 * Averages `wafomCalculator.WAFOMIntegrand` over the points of `P`.
 *
 * While `P` reports itself as bisectable it is split in two with
 * `P.bisect` and the result is the mean of the two halves' results.
 * Otherwise every point yielded by iterating `P` is fed to the calculator,
 * and the sum is scaled by `2^(-P.dimensionF2)`.
 *
 * Params:
 *   P               = the (shifted) digital net whose points are visited
 *   wafomCalculator = table-based evaluator of the per-point integrand
 *
 * Returns: the scaled sum (or the mean of the two halves) described above.
 */
private real WAFOMimpl(U, Size, size_t chunkSize)(ShiftedDigitalNet!(U, Size) P, WAFOMCalculator!(U, chunkSize) wafomCalculator)
{
	if (!P.bisectable)
	{
		// Leaf case: accumulate the integrand over every point directly.
		real total = 0;
		foreach (point; P)
			total += wafomCalculator.WAFOMIntegrand(point);
		return total * exp2(-cast(ptrdiff_t)P.dimensionF2);
	}

	// Recursive case: evaluate both halves and take their mean.
	auto halves = P.bisect;
	immutable real first = halves[0].WAFOMimpl(wafomCalculator);
	immutable real second = halves[1].WAFOMimpl(wafomCalculator);
	return (first + second) / 2;
}
29 
/// Memoized front end of `_getWAFOMCalculator`: calls with the same
/// `(n, c, exponent)` arguments reuse the cached calculator instance.
/// `chunkSize` defaults to 8 bits per lookup-table row.
template getWAFOMCalculator(U, size_t chunkSize = 8)
{
	alias getWAFOMCalculator = memoize!(_getWAFOMCalculator!(U, chunkSize));
}
31 
/**
 * Builds a fresh `WAFOMCalculator!(U, chunkSize)`.
 *
 * Params:
 *   n        = number of bit positions the calculator's tables cover
 *   c        = offset forwarded to the calculator's constructor
 *   exponent = exponent multiplier forwarded to the calculator's constructor
 *
 * Returns: the newly constructed calculator.
 */
auto _getWAFOMCalculator(U, size_t chunkSize)(size_t n, real c, real exponent)
{
	alias Calculator = WAFOMCalculator!(U, chunkSize);
	auto calculator = new Calculator(n, c, exponent);
	return calculator;
}
36 
/**
 * Table-driven evaluator of the per-point integrand used by `WAFOMimpl`.
 *
 * For each of the `n` bit positions `j` (counted from the least significant
 * bit) the constructor derives `eps = 2^(exponent * (c - 1 - (n - j)))` and
 * the two factors `1 + eps` (bit clear) and `1 - eps` (bit set).  The
 * factors are grouped `chunkSize` bits at a time; for every group a table
 * with `2^chunkSize` entries stores the product of the factors selected by
 * each possible bit pattern, so a coordinate is evaluated with one table
 * lookup per group.
 *
 * Params:
 *   U         = integer type holding one coordinate
 *   chunkSize = number of bits consumed per table lookup; must lie in
 *               `1 .. U.sizeof * 8` inclusive
 */
private class WAFOMCalculator(U, size_t chunkSize)
{
	// A zero chunk would never consume a factor in the constructor's loop,
	// and a chunk wider than U would underflow the shift count of `mask`.
	static assert(0 < chunkSize && chunkSize <= bitsOfU,
		"WAFOMCalculator: chunkSize must be between 1 and the bit width of U");

	/**
	 * Precomputes the lookup tables.
	 *
	 * Params:
	 *   n        = number of bit positions to cover
	 *   c        = offset in the per-bit exponent
	 *   exponent = multiplier of the per-bit exponent
	 */
	this (in size_t n, real c, real exponent)
	{
		// t[j] = [factor when bit j is 0, factor when bit j is 1]
		auto t = new real[2][n];
		foreach (j, ref u; t)
		{
			immutable eps = exp2(exponent * (c - 1 - (n - j)));
			u[0] = 1 + eps;
			u[1] = 1 - eps;
		}
		real[1 << chunkSize][] memo;
		while (t.length)
		{
			// One table row per group of up to chunkSize remaining factors.
			memo.length += 1;
			foreach (i, ref x; memo[$ - 1])
			{
				x = 1;
				// Bit j of the table index i selects which factor of t[j] to use.
				foreach (j, u; t[0..min($, chunkSize)])
					x *= u[i >> j & 1];
			}
			t = t[min($, chunkSize)..$];
		}
		this.m = memo.idup;
	}

	/**
	 * Evaluates the integrand for one point.
	 *
	 * Params:
	 *   X = the coordinates of the point
	 *
	 * Returns: the product of the per-coordinate values, minus 1.
	 */
	real WAFOMIntegrand(U[] X)
	{
		real ret = 1;
		foreach (x; X)
			ret *= WAFOMIntegrand(x);
		return ret - 1;
	}
private:
	/// Product of the table entries selected by successive `chunkSize`-bit
	/// groups of `x`, starting from the least significant bits.
	real WAFOMIntegrand(U x)
	{
		real ret = 1;
		foreach (row; m)
		{
			ret *= row[x & mask];
			// Shifting by the full width of U is not a legal shift, so when
			// a single chunk spans all of U there is simply nothing left.
			static if (chunkSize < bitsOfU)
				x >>= chunkSize;
			else
				x = 0;
		}
		return ret;
	}
	/// Number of bits in U.
	enum size_t bitsOfU = U.sizeof << 3;
	/// Low `chunkSize` bits set; extracts one table index from a coordinate.
	enum mask = U.max >> (bitsOfU - chunkSize);
	/// m[k][i]: product of the factors for bit group k under bit pattern i.
	immutable real[1 << chunkSize][] m;
}