Skip to content
Snippets Groups Projects
Verified Commit 93dfb9b1 authored by iliya.saroukha's avatar iliya.saroukha :first_quarter_moon:
Browse files

feat: Adam seems to be working

parent 16c5edfe
Branches
No related tags found
No related merge requests found
import numpy as np
from sympy import symbols, sin, cos, exp, sqrt, pi, lambdify, Function
from sympy import symbols, sin, sqrt, cos, exp, sqrt, pi, lambdify, Function
import pandas as pd
from typing import Callable
import matplotlib.pyplot as plt
......@@ -112,11 +112,59 @@ def nesterov_gd(f: Function, init_pt: list[float], lr: float, momentum: float)\
return df
def adam_gd(f: Function, init_pt: list[float], lr: float)\
-> float:
df = pd.DataFrame(columns=['x', 'y', 'Cost', 'NormGrad'])
partialx, partialy = callable_grad_2d(f)
f_call = callable_func(f)
beta1 = 0.9
beta2 = 0.999
m = np.array([0, 0])
v = np.array([0, 0])
x, y = init_pt
iter = 0
grad = np.array([partialx(x, y), partialy(x, y)])
while iter < 1e4:
iter += 1
grad = np.array([partialx(x, y), partialy(x, y)])
if np.linalg.norm(grad) < 1e-6:
break
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * np.square(grad)
if iter == 0:
m_hat = np.array([0, 0])
v_hat = np.array([0, 0])
else:
m_hat = m / (1 - beta1**iter)
v_hat = v / (1 - beta2**iter)
eta = lr / (np.sqrt(v_hat) + 1e-8)
df.loc[iter] = [x, y, f_call(x, y), np.linalg.norm(grad)]
step = eta * m_hat
x -= step[0]
y -= step[1]
iter += 1
return df
if __name__ == "__main__":
x, y = symbols('x y')
# f: Function = x**2 + 6 * y**2
f: Function = 1 - exp(-10 * x**2 - y**2)
# f: Function = x**2 + 5 * y**2
# f: Function = 1 - exp(-10 * x**2 - y**2)
# f: Function = x**2 * y - 2 * x * y**3 + 3 * x * y + 4
# Rosenbrock(x, y)
......@@ -130,28 +178,30 @@ if __name__ == "__main__":
# f: Function = (x + 2 * y - 7)**2 + (2 * x + y - 5)**2
# Ackley(x, y)
# f: Function = -20.0 * exp(-0.2 * sqrt(0.5 * (x**2 + y**2))) - \
# exp(0.5 * (cos(2 * pi * x) + cos(2 * pi * y))) + exp(1) + 20
f: Function = -20.0 * exp(-0.2 * sqrt(0.5 * (x**2 + y**2))) - \
exp(0.5 * (cos(2 * pi * x) + cos(2 * pi * y))) + exp(1) + 20
f_call = callable_func(f)
LR = 1e-2
MOMENTUM = 0.9
plot_range = (2, 2)
plot_range = (10, 10)
init_pt = [1, 1]
# init_pt = [1, 1]
# init_pt = np.array([np.random.randint(-plot_range[0], plot_range[0] + 1),
# np.random.randint(-plot_range[1], plot_range[1] + 1)])
init_pt = np.array([np.random.randint(-plot_range[0], plot_range[0] + 1),
np.random.randint(-plot_range[1], plot_range[1] + 1)])
base = base_gd(f, init_pt, LR)
momentum = momentum_gd(f, init_pt, LR, MOMENTUM)
nesterov = nesterov_gd(f, init_pt, LR, MOMENTUM)
adam = adam_gd(f, init_pt, 0.7)
print(f"Base = {base.tail(1)}")
print(f"Momentum = {momentum.tail(1)}")
print(f"Nesterov = {nesterov.tail(1)}")
print(f"Adam = {adam.tail(1)}")
x = np.linspace(-plot_range[0], plot_range[0], 100)
y = np.linspace(-plot_range[1], plot_range[1], 100)
......@@ -174,6 +224,8 @@ if __name__ == "__main__":
color='cyan', label='Momentum')
ax.plot(nesterov['x'], nesterov['y'], '-o',
color='orange', label='Nesterov')
plt.plot(adam['x'], adam['y'], '-o',
color='violet', label='Adam')
ax.legend()
plt.figure(2)
......@@ -183,6 +235,8 @@ if __name__ == "__main__":
color='cyan', label='Momentum')
plt.plot(nesterov['x'], nesterov['y'], '-o',
color='orange', label='Nesterov')
plt.plot(adam['x'], adam['y'], '-o',
color='violet', label='Adam')
plt.legend()
plt.xlabel('x')
plt.ylabel('y')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment