From a5728da89d37ed994513676647ec05d64323fbb0 Mon Sep 17 00:00:00 2001
From: "abir.chebbi" <abir.chebbi@hes-so.ch>
Date: Thu, 12 Sep 2024 10:40:21 +0200
Subject: [PATCH] Parameterize vector DB creation and document upload scripts

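The resource names that were previously hard-coded (S3 bucket, collection,
endpoint, index, IAM user) are now read from the command line via argparse,
and the OpenSearch client in Part 1/main.py is built inside main() from the
endpoint argument.

Illustrative invocations (argument values are placeholders):

    python "Part 1/create-S3-and-put-docs.py" <bucket-name> <local-pdf-dir>
    python "Part 1/create-vector-db.py" <collection-name> <iam-user>
    python "Part 1/main.py" <bucket-name> <collection-endpoint> <index-name>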
---
 Part 1/create-S3-and-put-docs.py |  34 +++---
 Part 1/create-vector-db.py       | 173 ++++++++++++++++---------------
 Part 1/main.py                   |  48 +++++----
 Part 2/main.py                   |   1 +
 4 files changed, 135 insertions(+), 121 deletions(-)

diff --git a/Part 1/create-S3-and-put-docs.py b/Part 1/create-S3-and-put-docs.py
index 15eab79..19b310f 100644
--- a/Part 1/create-S3-and-put-docs.py	
+++ b/Part 1/create-S3-and-put-docs.py	
@@ -1,22 +1,18 @@
 import boto3
 import os
+import argparse
 
-LOCAL_DIR = "pdfs"
-BUCKET_NAME = 'cloud-lecture-nabil-2024-25'
 
-# Initiate S3 client
-s3_client = boto3.client('s3')
+def create_bucket(s3_client, bucket_name):
+    """ Create an S3 bucket """
+    print("Creating Bucket")
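+    # NOTE: without a CreateBucketConfiguration this call assumes the client region is us-east-1; other regions require a LocationConstraint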
+    response = s3_client.create_bucket(Bucket=bucket_name)
+    print(response)
+    print()
 
-# Create S3 Bucket
-print("Creating Bucket")
-response = s3_client.create_bucket(
-    Bucket=BUCKET_NAME,
-)
-print(response)
-print()
 
 # Function to write files to S3
-def write_files(directory, bucket):
+def write_files(s3_client, directory, bucket):
     for filename in os.listdir(directory):
         if filename.endswith(".pdf"):  # Check if the file is a PDF
             file_path = os.path.join(directory, filename)
@@ -28,8 +24,16 @@ def write_files(directory, bucket):
                     Key=filename
                 )
                 print(f"{filename} uploaded successfully.")
+                
+def main(bucket_name, local_dir):
+    s3_client = boto3.client('s3')
+    create_bucket(s3_client, bucket_name)
+    write_files(s3_client, local_dir, bucket_name)
 
-# Upload PDF files to S3 bucket
-print("Writing Items to Bucket")
-write_files(LOCAL_DIR, BUCKET_NAME)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Upload PDF files to an S3 bucket")
+    parser.add_argument("bucket_name", help="The name of the S3 bucket to which the files will be uploaded")
+    parser.add_argument("LOCAL_DIR", help="The name of the folder to put the pdf files")
+    args = parser.parse_args()
+    main(args.bucket_name, args.local_dir)
 
diff --git a/Part 1/create-vector-db.py b/Part 1/create-vector-db.py
index c411136..68aedb8 100644
--- a/Part 1/create-vector-db.py	
+++ b/Part 1/create-vector-db.py	
@@ -2,30 +2,30 @@
 import boto3
 import botocore
 import time
+import argparse
 
 
 client = boto3.client('opensearchserverless')
 #service = 'aoss'
-Vector_store_name='test-nabil'
 
-def createEncryptionPolicy(client):
-    """Creates an encryption policy that matches all collections beginning with test"""
+def createEncryptionPolicy(client, policy_name, collection_name):
+    """Creates an encryption policy for the specified collection."""
     try:
         response = client.create_security_policy(
-            description='Encryption policy for test collections',
-            name='test-policy',
-            policy="""
-                {
-                    \"Rules\":[
-                        {
-                            \"ResourceType\":\"collection\",
-                            \"Resource\":[
-                                \"collection\/test*\"
+            description=f'Encryption policy for {collection_name}',
+            name=policy_name,
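+            # literal JSON braces are doubled ({{ }}) so they survive f-string interpolation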
+            policy=f"""
+                {{
+                    \"Rules\": [
+                        {{
+                            \"ResourceType\": \"collection\",
+                            \"Resource\": [
+                                \"collection/{collection_name}\"
                             ]
-                        }
+                        }}
                     ],
-                    \"AWSOwnedKey\":true
-                }
+                    \"AWSOwnedKey\": true
+                }}
                 """,
             type='encryption'
         )
@@ -39,27 +39,27 @@ def createEncryptionPolicy(client):
             raise error
 
 
-def createNetworkPolicy(client):
-    """Creates a network policy that matches all collections beginning with test"""
+def createNetworkPolicy(client, policy_name, collection_name):
+    """Creates a network policy for the specified collection."""
     try:
         response = client.create_security_policy(
-            description='Network policy for Test collections',
-            name='test-policy',
-            policy="""
-                [{
-                    \"Description\":\"Public access for Test collection\",
-                    \"Rules\":[
-                        {
-                            \"ResourceType\":\"dashboard\",
-                            \"Resource\":[\"collection\/test*\"]
-                        },
-                        {
-                            \"ResourceType\":\"collection\",
-                            \"Resource\":[\"collection\/test*\"]
-                        }
+            description=f'Network policy for {collection_name}',
+            name=policy_name,
+            policy=f"""
+                [{{
+                    \"Description\": \"Public access for {collection_name}\",
+                    \"Rules\": [
+                        {{
+                            \"ResourceType\": \"dashboard\",
+                            \"Resource\": [\"collection/{collection_name}\"]                            
+                        }},
+                        {{
+                            \"ResourceType\": \"collection\",
+                            \"Resource\": [\"collection/{collection_name}\"]                            
+                        }}
                     ],
-                    \"AllowFromPublic\":true
-                }]
+                    \"AllowFromPublic\": true
+                }}]
                 """,
             type='network'
         )
@@ -73,65 +73,62 @@ def createNetworkPolicy(client):
             raise error
 
 
-def createAccessPolicy(client):
-    """Creates a data access policy that matches all collections beginning with test"""
+def createAccessPolicy(client, policy_name, collection_name, IAM_USER):
+    """Creates a data access policy for the specified collection."""
     try:
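+        # NOTE: the principal ARN below hard-codes an AWS account ID; adjust it to your own account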
+        policy_content = f"""
+        [
+            {{
+                "Rules": [
+                    {{
+                        "Resource": ["collection/{collection_name}"],
+                        "Permission": [
+                            "aoss:CreateCollectionItems",
+                            "aoss:DeleteCollectionItems",
+                            "aoss:UpdateCollectionItems",
+                            "aoss:DescribeCollectionItems"
+                        ],
+                        "ResourceType": "collection"
+                    }},
+                    {{
+                        "Resource": ["index/{collection_name}/*"],
+                        "Permission": [
+                            "aoss:CreateIndex",
+                            "aoss:DeleteIndex",
+                            "aoss:UpdateIndex",
+                            "aoss:DescribeIndex",
+                            "aoss:ReadDocument",
+                            "aoss:WriteDocument"
+                        ],
+                        "ResourceType": "index"
+                    }}
+                ],
+                "Principal": ["arn:aws:iam::352909266144:user/{IAM_USER}"]
+            }}
+        ]
+        """
         response = client.create_access_policy(
-            description='Data access policy for Test collections',
-            name='test-policy',
-            policy="""
-                [{
-                    \"Rules\":[
-                        {
-                            \"Resource\":[
-                                \"index\/test*\/*\"
-                            ],
-                            \"Permission\":[
-                                \"aoss:CreateIndex\",
-                                \"aoss:DeleteIndex\",
-                                \"aoss:UpdateIndex\",
-                                \"aoss:DescribeIndex\",
-                                \"aoss:ReadDocument\",
-                                \"aoss:WriteDocument\"
-                            ],
-                            \"ResourceType\": \"index\"
-                        },
-                        {
-                            \"Resource\":[
-                                \"collection\/test*\"
-                            ],
-                            \"Permission\":[
-                                \"aoss:CreateCollectionItems\",
-                                \"aoss:DeleteCollectionItems\",
-                                \"aoss:UpdateCollectionItems\",
-                                \"aoss:DescribeCollectionItems\"
-                            ],
-                            \"ResourceType\": \"collection\"
-                        }
-                    ],
-                    \"Principal\":[
-                        \"arn:aws:iam::768034348959:user/AbirChebbi\"
-                    ]
-                }]
-                """,
+            description=f'Data access policy for {collection_name}',
+            name=policy_name,
+            policy=policy_content,
             type='data'
         )
         print('\nAccess policy created:')
         print(response)
     except botocore.exceptions.ClientError as error:
         if error.response['Error']['Code'] == 'ConflictException':
-            print(
-                '[ConflictException] An access policy with this name already exists.')
+            print('[ConflictException] An access policy with this name already exists.')
         else:
             raise error
+
         
 
         
-def waitForCollectionCreation(client):
+def waitForCollectionCreation(client, collection_name):
     """Waits for the collection to become active"""
-    time.sleep(40)
+    time.sleep(30)
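+    # NOTE: a fixed sleep is a simplification; polling batch_get_collection until the status is ACTIVE would be more robust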
     response = client.batch_get_collection(
-            names=['test1'])
+            names=[collection_name])
     print('\nCollection successfully created:')
     print(response["collectionDetails"])
     # Extract the collection endpoint from the response
@@ -140,16 +137,22 @@ def waitForCollectionCreation(client):
     return final_host
 
 
-def main():
-
-    createEncryptionPolicy(client)
-    createNetworkPolicy(client)
-    createAccessPolicy(client)
-    collection = client.create_collection(name=Vector_store_name,type='VECTORSEARCH')
-    ENDPOINT= waitForCollectionCreation(client)
+def main(collection_name, IAM_USER):
+    encryption_policy_name = f'{collection_name}-encryption-policy'
+    network_policy_name = f'{collection_name}-network-policy'
+    access_policy_name = f'{collection_name}-access-policy'
+    createEncryptionPolicy(client, encryption_policy_name, collection_name)
+    createNetworkPolicy(client, network_policy_name, collection_name)
+    createAccessPolicy(client, access_policy_name, collection_name, IAM_USER)
+    collection = client.create_collection(name=collection_name, type='VECTORSEARCH')
+    ENDPOINT = waitForCollectionCreation(client, collection_name)
 
     print("Collection created successfully:", collection)
     print("Collection ENDPOINT:", ENDPOINT)
 
 if __name__== "__main__":
-    main()
\ No newline at end of file
+    parser = argparse.ArgumentParser(description="Create an OpenSearch Serverless vector collection and its security policies")
+    parser.add_argument("collection_name", help="The name of the collection to create")
+    parser.add_argument("iam_user", help="The IAM user to grant data access to the collection")
+    args = parser.parse_args()
+    main(args.collection_name, args.iam_user)
diff --git a/Part 1/main.py b/Part 1/main.py
index 36feb7e..ab8850d 100644
--- a/Part 1/main.py	
+++ b/Part 1/main.py	
@@ -8,41 +8,30 @@ from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
 from langchain_community.vectorstores import OpenSearchVectorSearch
 import uuid
 import json
+import argparse
 
 
 ##  Local directory for storing PDF files
 LOCAL_DIR = "pdfs" 
-index_name = "cloud_lecture"
+
 
 
 ## S3_client
 s3_client = boto3.client('s3')
-## Bucket name where documents are stored
-BUCKET_NAME = "cloud-lecture-2023"
 
 ## Bedrock client
 bedrock_client = boto3.client(service_name="bedrock-runtime")
 
 
 ## Configuration for AWS authentication and OpenSearch client
-credentials = boto3.Session().get_credentials()
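+# NOTE: this pins a specific AWS CLI profile; drop profile_name to use the default credential chain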
+credentials = boto3.Session(profile_name='master-group-14').get_credentials()
 awsauth = AWSV4SignerAuth(credentials, 'us-east-1', 'aoss')
 
-## Vector DB endpoint
-host= 'j6phg34iv0f2rlvxwawd.us-east-1.aoss.amazonaws.com'
 
-## Opensearch Client
-OpenSearch_client = OpenSearch(
-    hosts=[{'host': host, 'port': 443}],
-    http_auth=awsauth,
-    use_ssl=True,
-    verify_certs=True,
-    connection_class=RequestsHttpConnection,
-    
-)
+
 
 ## Create Index in Opensearch
-def create_index(index_name):
+def create_index(client, index_name):
     indexBody = {
         "settings": {
             "index.knn": True
@@ -62,7 +51,7 @@ def create_index(index_name):
     }
 
     try:
-        create_response = OpenSearch_client.indices.create(index_name, body=indexBody)
+        create_response = client.indices.create(index_name, body=indexBody)
         print('\nCreating index:')
         print(create_response)
     except Exception as e:
@@ -101,6 +90,7 @@ def generate_embeddings(bedrock_client, chunks):
 
 # Store generated embeddings into an OpenSearch index.
 def store_embeddings(embeddings, texts, meta_data, host, awsauth, index_name):  
+    
     docsearch = OpenSearchVectorSearch.from_embeddings(
         embeddings,
         texts,
@@ -137,14 +127,25 @@ def generate_store_embeddings(bedrock_client, chunks,awsauth,index_name):
 
 
 ## main 
-def main():
+def main(bucket_name, endpoint, index_name):
+
+    ## Opensearch Client
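+    # NOTE: 'endpoint' is expected to be the bare collection host (no https:// prefix), e.g. <id>.us-east-1.aoss.amazonaws.com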
+    OpenSearch_client = OpenSearch(
+        hosts=[{'host': endpoint, 'port': 443}],
+        http_auth=awsauth,
+        use_ssl=True,
+        verify_certs=True,
+        connection_class=RequestsHttpConnection,
+        
+    )
 
-    download_documents(BUCKET_NAME,LOCAL_DIR)
+    download_documents(bucket_name, LOCAL_DIR)
     loader= PyPDFDirectoryLoader(LOCAL_DIR)
     docs = loader.load()
     print(docs[1])
     chunks = split_text(docs, 1000, 100)
     print(chunks[1])
+    create_index(OpenSearch_client, index_name)
     embeddings= generate_embeddings(bedrock_client, chunks)
     print(embeddings[1])
     texts = [chunk.page_content for chunk in chunks]
@@ -152,7 +153,7 @@ def main():
     meta_data = [{'source': chunk.metadata['source'], 'page': chunk.metadata['page'] + 1} for chunk in chunks]
     print(embeddings[1])
     print(meta_data[1])
-    store_embeddings(embeddings, texts, meta_data ,host, awsauth,index_name)
+    store_embeddings(embeddings, texts, meta_data, endpoint, awsauth, index_name)
 
 
    
@@ -163,4 +164,9 @@ def main():
 
 
 if __name__== "__main__":
-    main()
+    parser = argparse.ArgumentParser(description="Process PDF documents and store their embeddings.")
+    parser.add_argument("bucket_name", help="The S3 bucket name where documents are stored")
+    parser.add_argument("endpoint", help="The OpenSearch service endpoint")
+    parser.add_argument("index_name", help="The name of the OpenSearch index")
+    args = parser.parse_args()
+    main(args.bucket_name, args.endpoint, args.index_name)
diff --git a/Part 2/main.py b/Part 2/main.py
index 713a2bd..0a9d201 100644
--- a/Part 2/main.py	
+++ b/Part 2/main.py	
@@ -107,6 +107,7 @@ def main():
         st.session_state.chat_history.append({"role": "user", "content": user_prompt})
         # Generate and display answer
         print(user_prompt)
+ 
         embed_question= get_embedding(user_prompt,bedrock_client)
         print(embed_question)
         sim_results = similarity_search(embed_question, index_name)
-- 
GitLab