From 5d99da7199cd009ec492f3c568471bae92fd6afd Mon Sep 17 00:00:00 2001
From: Orestis Malaspinas <orestis.malaspinas@hesge.ch>
Date: Tue, 24 Mar 2020 14:20:50 +0100
Subject: [PATCH] added partial application of function

---
 covid/python/seir.py | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/covid/python/seir.py b/covid/python/seir.py
index 6411cb3..bcb3847 100644
--- a/covid/python/seir.py
+++ b/covid/python/seir.py
@@ -1,20 +1,14 @@
 import numpy as np
 import matplotlib.pyplot as plt
+import functools
 
 swiss = np.array([2.100350058, 3.150525088, 4.900816803, 6.534422404, 10.501750292, 13.302217036, 24.970828471, 31.271878646, 39.323220537, 43.640606768, 57.292882147, 76.079346558, 100.116686114, 131.271878646, 158.576429405, 256.709451575, 256.709451575, 268.378063011, 309.218203034, 353.325554259, 453.675612602])
 swiss = np.array([18, 27, 42, 56, 90, 114, 214, 268, 337, 374, 491, 652, 858, 1125, 1359, 2200, 2200, 2300, 2650, 3028, 3888])
 
 days = np.array(range(1,len(swiss)+1))
 
-R_0 = 4.26
-R_01 = 4.26
-R_02 = 1 - 1/R_01
-
-Tinf = 7.0
-Tinc = 5.1
-N = 8000000.0
-
-def seir(y, t):
+def seir(y, t, R_0, Tinf, Tinc):
+    N = np.sum(y)
     y1 = np.zeros(4)
     y1[0] = - R_0 / Tinf * y[2] * y[0] / N
     y1[1] =  R_0 / Tinf * y[2] * y[0] / N - 1.0 / Tinc * y[1]
@@ -30,15 +24,23 @@ def rk4(F, y, t, dt):
     k4 = dt * F(y + k3, t + dt)
     return y + (k1 + 2 * (k2 + k3) + k4) / 6
 
-I0 = 0
-E0 = swiss[0]
+R_0 = 4.26
+R_01 = 4.26
+R_02 = 1 - 1/R_01
+
+Tinf = 7.0
+Tinc = 5.1
+N = 500000
+
+I0 = 214 / 0.2
+E0 = 2000
 R0 = 0
 S0 = N-E0-I0
-t0 = -Tinc
+t0 = 24
 
 # max_t = 5*days[len(swiss)-1]
 max_t = 2000
-n_steps = 1000000
+n_steps = 100000
 dt = max_t / n_steps
 
 y0 = np.array([S0, E0, I0, R0])
@@ -47,10 +49,11 @@ y_list = [y0]
 t_list = [t0]
 for i in range(0, n_steps):
     t = t_list[i] + dt
-    y1 = rk4(seir, y_list[i], t, dt)
+    foo = functools.partial(seir, R_0=R_0, Tinf=Tinf, Tinc=Tinc)
+    y1 = rk4(foo, y_list[i], t, dt)
     y_list.append(y1)
     t_list.append(t)
-    if (t >= 27 and t <= 67) or (t >= 98 and t <= 122) :
+    if (t > t0 and t <= t0 + 30) or (t >= t0 + 60 and t <= t0 + 90) :
         R_0 = R_02
     else:
         R_0 = R_01
@@ -67,7 +70,7 @@ plt.semilogy(t, p[0], 'b')
 plt.semilogy(t, p[1], 'r')
 plt.semilogy(t, p[2], 'k')
 plt.semilogy(t, p[3], 'g')
-plt.semilogy(days, swiss, 'k*')
+# plt.semilogy(days, swiss, 'k*')
 plt.legend(['S', 'E', 'I', 'R', 'swiss'])
 plt.show()
 
-- 
GitLab