from requests import get
from datetime import datetime, timedelta
import time

# MeSH headings to query, already URL-encoded ("+" for spaces, "%2C" for
# commas) and wrapped in double quotes so each is searched as one phrase.
TERMS = [
    # Noncommunicable diseases (all)
    '"Noncommunicable+Diseases"',
    # Diabetes (type 1 or 2)
    '"Diabetes+Mellitus"',
    # Cancer
    '"Neoplasms"',
    # Chronic respiratory disease
    '"Respiratory+Tract+Diseases"',
    # Cardiovascular diseases
    '"Cardiovascular+Diseases"',
    # Mental health
    '"Mental+Health"',
    # Diabetes type 1
    '"Diabetes+Mellitus%2C+Type+1"',
    # Diabetes type 2
    '"Diabetes+Mellitus%2C+Type+2"',
]

# Window sizes for which per-window statistics are collected.
INTERVALS = ["day", "week", "month"]

def get_count_for_year(year, term, interval="month", *, http_get=None,
                       pause=1.0, timeout=30):
    """Summarise PubMed hit counts for *term* over each interval of *year*.

    The year is split into consecutive, non-overlapping windows (days,
    weeks or calendar months). One ESearch request is sent per window and
    the hit count reported by the server is recorded.

    Every day of the year belongs to exactly one window, and no window
    reaches into the following year. As a consequence the last "week"
    window may be shorter than seven days.

    Args:
        year: Calendar year to scan.
        term: Search term, inserted verbatim into the query string; it is
            expected to be URL-encoded already.
        interval: Window size, one of "day", "week" or "month".
        http_get: Callable used to perform the request, called as
            ``http_get(url, timeout=timeout)``. It must return an object
            with ``raise_for_status()`` and ``json()``. Defaults to the
            module-level ``get`` imported from ``requests``.
        pause: Seconds to wait between two consecutive requests.
        timeout: Timeout in seconds forwarded to ``http_get``.

    Returns:
        A dict with the keys "max", "min" and "avg" holding the largest,
        smallest and mean per-window hit count.

    Raises:
        ValueError: If *interval* is not a supported value.
    """
    if interval not in ("day", "week", "month"):
        raise ValueError(
            f"unknown interval {interval!r}; expected 'day', 'week' or 'month'"
        )
    if http_get is None:
        http_get = get  # requests.get, imported at the top of the file

    one_day = timedelta(days=1)
    year_end = datetime(year + 1, 1, 1)  # exclusive upper bound of the scan
    current_date = datetime(year, 1, 1)
    counts = []

    while current_date < year_end:
        if interval == "day":
            next_date = current_date + one_day
        elif interval == "week":
            next_date = current_date + timedelta(weeks=1)
        else:
            # Day 28 exists in every month and 28 + 4 days always lands in
            # the following month, so snapping back to day 1 yields the
            # first day of the next month.
            next_date = (
                current_date.replace(day=28) + timedelta(days=4)
            ).replace(day=1)
        # Never let a window run past the end of the requested year.
        next_date = min(next_date, year_end)

        # ESearch date bounds are inclusive, so the window must end on the
        # day *before* the next window starts; otherwise boundary days are
        # counted twice.
        last_day = next_date - one_day
        url = (
            "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
            f"?db=pubmed&term={term}&retmode=json"
            f"&mindate={current_date:%Y/%m/%d}"
            f"&maxdate={last_day:%Y/%m/%d}"
            "&usehistory=y"
        )
        response = http_get(url, timeout=timeout)
        # Fail loudly on HTTP errors (e.g. rate limiting) instead of trying
        # to parse an error page as a search result.
        response.raise_for_status()
        search_res = response.json()
        counts.append(int(search_res["esearchresult"]["count"]))

        current_date = next_date
        if current_date < year_end:
            # Throttle: NCBI limits clients without an API key to about
            # 3 requests per second and may block the IP address beyond that.
            time.sleep(pause)

    return {
        "max": max(counts),
        "min": min(counts),
        "avg": sum(counts) / len(counts),
    }

# Results, keyed first by the raw entry from TERMS and then by interval name.
data = {}

for term in TERMS:
    # Create the term's slot up front so it exists even if a request fails.
    data[term] = {}
    # Append PubMed's [Mesh] field tag to the quoted heading.
    mesh = f"{term}[Mesh]"
    print("TERM: ", mesh)

    for interval in INTERVALS:
        print("INTERVAL: ", interval)
        counts = get_count_for_year(2024, mesh, interval)
        print(counts)
        data[term].update({interval: counts})

# Dump the complete nested mapping once every term has been processed.
print(data)