diff --git a/doc/trimstart.py b/doc/trimstart.py
index 1da3fdcdf8c45d734b2b28f832f3069df040608c..635c90d806106b30c1ea0f05ba64683109b39b2a 100755
--- a/doc/trimstart.py
+++ b/doc/trimstart.py
@@ -1,11 +1,41 @@
 #!/usr/bin/env python3
+from argparse import ArgumentParser
+from fileinput import input
+from sys import stdout
 
-import sys
+from typing import Iterable, Iterator
 
-nonempty_line_reached = False
-for i, line in enumerate(sys.stdin):
-    if not nonempty_line_reached:
-        if line.rstrip() == "":
-            continue
-        nonempty_line_reached = True 
-    print(line, end='')
+
+def lstrip_stream(stream: Iterable[str]) -> Iterator[str]:
+    """
+    Skips the leading whitespace at the beginning of a stream of strings.
+    :param stream: Iterable[str] A buffered stream of string.
+    :return: Generator[str] An iterator that yields the data from the input stream once the first non-whitespace
+                            character is reached.
+    """
+
+    for first_nonempty_buffer in filter(bool, map(str.lstrip, stream)):
+        yield first_nonempty_buffer
+        break
+    else:
+        return
+
+    yield from stream
+
+
+def handle_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        'files',
+        metavar='FILE',
+        nargs='*',
+        help='Files to read sequentially as one stream. If no file is given, stdin is read instead.'
+    )
+    parser.parse_args()
+
+
+if __name__ == '__main__':
+    handle_args()
+    with input() as input_stream:
+        for line in lstrip_stream(input()):
+            stdout.write(line)