diff --git a/Correction/heat_gpu.cpp b/Correction/heat_gpu.cpp
index ea2d248af1194da7d22639149c359e1e30619e14..fff45c00ae280fc01ace2a7d0d145a1fce3188cf 100644
--- a/Correction/heat_gpu.cpp
+++ b/Correction/heat_gpu.cpp
@@ -11,6 +11,7 @@ constexpr auto kPolicy = std::execution::par_unseq;
 
 template <typename T>
 void stepSlow(Matrix<T> &current, Matrix<T> &next) {
+    auto rows = current.rows();
     auto cols = current.cols();
 
     auto idxs_rows = std::views::iota(1, rows - 1);
@@ -61,6 +62,41 @@ void step(Matrix<T> &current, Matrix<T> &next) {
                   });
 }
 
+// Row-major 1D indices of every interior (non-border) cell of a rows x cols grid;
+// returns an empty vector when there is no interior (rows < 3 or cols < 3).
+std::vector<int> computeIdxs(int rows, int cols) {
+    if (rows < 3 || cols < 3) return {};  // else negative extents give a bogus size or out-of-range indices
+    const int inner_cols = cols - 2;
+    const int count = (rows - 2) * inner_cols;
+    std::vector<int> idxs(count, 0);
+    auto iota = std::views::iota(0, count);
+    std::transform(kPolicy, iota.begin(), iota.end(), idxs.begin(),
+                   [cols, inner_cols](int idx) {
+                       // idx enumerates interior cells row by row; shift by one row and one column
+                       return (idx / inner_cols + 1) * cols + idx % inner_cols + 1;
+                   });
+    return idxs;
+}
+
+// For every flat index i in idxs, writes into next the average of the four neighbours
+// of i in current (i-1, i+1, i-cols, i+cols). idxs is only read; each index must have all
+// four neighbours inside the data (e.g. the interior indices produced by computeIdxs).
+template <typename T>
+void stepPreComputedIdxs(const std::vector<int> &idxs, Matrix<T> &current, Matrix<T> &next) {
+    // Signed row stride: avoids relying on unsigned wraparound should cols() be unsigned.
+    const int stride = static_cast<int>(current.cols());
+
+    std::for_each(kPolicy,
+                  idxs.begin(), idxs.end(),
+                  [cur = current.dataSpan(),
+                   nxt = next.dataSpan(),
+                   stride](int i) {
+                      // Summation order (left, right, up, down) is kept so rounding is unchanged.
+                      nxt[i] = 0.25 * (cur[i - 1] + cur[i + 1] +
+                                       cur[i - stride] + cur[i + stride]);
+                  });
+}
+
 int main(int argc, char **argv) {
     int m = std::stoi(argv[1]);
     int n = std::stoi(argv[2]);
@@ -80,8 +116,10 @@ int main(int argc, char **argv) {
     init_matrix(*Unext);
 
     auto start_iter = std::chrono::steady_clock::now();
+    // auto idxs = computeIdxs(U->rows(), U->cols());
     for (int t = 0; t < tmax; t++) {
         step(*U, *Unext);
+        // stepPreComputedIdxs(idxs, *U, *Unext);
         std::swap(U, Unext);
         if (t % 100 == 0) {
             std::cout << "t:" << t << std::endl;