Saturday, 3 May 2025

RAG Time with LLM

During the Easter long weekend, I wrote a RAG agent specialising in database queries: it can generate SQL queries and execute them. The result was pretty good and I had a blast along the way.

The application is a typical RAG agent with tool calls. I fed in a number of Confluence pages containing the database's data dictionary, which were pre-processed, chunked and embedded into in-memory storage. This storage is then searched for relevant context when chatting with the LLM.
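
For orientation, a minimal sketch of that ingestion step (the chunk size, overlap and `utils.EMBED_MODEL` name are illustrative, not necessarily what the app uses):

    def ingest_pages(self, pages, chunk_size=800, overlap=100):
        """Chunk raw Confluence page text and embed each chunk in memory.
        Sketch only: chunking parameters and EMBED_MODEL are illustrative."""
        for text in pages:
            # Naive fixed-size chunking with overlap between windows
            for start in range(0, len(text), chunk_size - overlap):
                chunk = text[start:start + chunk_size]
                embedding = utils.client.embeddings.create(
                    model=utils.EMBED_MODEL,
                    input=chunk
                ).data[0].embedding
                self.chunks.append(chunk)
                self.embeddings.append(embedding)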

The only special thing in my app is that, in the system prompt, I asked the LLM to make tool calls into the database to query the metadata of the tables involved in a SQL query whenever it is not sure about a table's structure from the provided context. This results in recursive tool calls. For example, if a SQL statement joins two tables and the LLM is not certain of their structures, it queries the database for each table, retrieves its structure, and uses that information to amend the SQL query...
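
For context, the `tools` list referenced in the code below follows the standard OpenAI function-calling schema; only the schema shape is fixed, while the function name and parameters here are illustrative:

    tools = [{
        "type": "function",
        "function": {
            "name": "get_table_metadata",   # illustrative name
            "description": "Return the column names, types and keys of a database table.",
            "parameters": {
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table to describe"
                    }
                },
                "required": ["table_name"]
            }
        }
    }]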

Therefore, my Chat_Bot class has a recursive method to handle tool calls:

    @staticmethod
    def process_tool_calls(response, messages=None, model=utils.CHAT_MODEL):
        # Avoid a mutable default argument; callers may pass their own list
        if messages is None:
            messages = []
        if response.choices[0].finish_reason != 'tool_calls':
            print("process_tool_calls returning: finish_reason=", response.choices[0].finish_reason)
            return response

        # Append the assistant message that carries the tool calls once,
        # then one "tool" message per call with its result
        messages.append(response.choices[0].message)
        for tool_call in response.choices[0].message.tool_calls:
            print(f"tool_call id={tool_call.id}")
            result = Chat_Bot.process_tool_call(tool_call)
            messages.append(
                {"role": "tool", "tool_call_id": tool_call.id, "content": str(result)}
            )

        response = utils.client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools, tool_choice='auto',
            stream=False
        )

        # Recurse until the model stops asking for tools
        return Chat_Bot.process_tool_calls(response, messages, model)
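
process_tool_call (singular) is where the actual database work happens: parse the arguments the LLM supplied, run the matching query, and hand the result back as a string. A sketch, assuming the illustrative get_table_metadata tool above and a hypothetical utils.run_query helper:

    @staticmethod
    def process_tool_call(tool_call):
        """Dispatch a single tool call to the database and return its result."""
        args = json.loads(tool_call.function.arguments)  # json imported at module level
        if tool_call.function.name == "get_table_metadata":
            # run_query is a hypothetical helper; in real code, validate
            # table_name instead of interpolating it into SQL
            return utils.run_query(f"DESCRIBE {args['table_name']}")
        return f"unknown tool: {tool_call.function.name}"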

This method is called by the main chat() method:

    def chat(self, prompt, model=utils.CHAT_MODEL, temperature=0.3):
        """
        Args:
            prompt (str): user prompt or question for the AI bot
            model (str): LLM model id
            temperature (float): ranging from 0.0 to 2.0; the bigger, the wilder
        """
        queries = [prompt]
        relevant_chunks = self.search_chunks(queries)
        # print("relevant_chunks: ", relevant_chunks)
       
        content = "\n".join(relevant_chunks)
       
        messages = [{"role": "system", "content": self.system_prompt}]
       
        # Add chat history if available
        if self.chat_history:
            messages.extend(self.chat_history)
       
        # Add current context and question
        messages.extend([
            {"role": "user", "content": f"Document Excerpt: {content}"},
            {"role": "user", "content": f"Question: {prompt}"}
        ])
       
        response = utils.client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            stream=False,
            tools=tools, tool_choice='auto'
            # max_tokens=4096
        )

        # invoke tools
        follow_up_response = Chat_Bot.process_tool_calls(response, messages, model)
           
        print("follow_up_response:", follow_up_response)
        answer = follow_up_response.choices[0].message.content

        # Update chat history
        self.chat_history.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer}
        ])
       
        # Keep only last 20 messages to prevent context from growing too large
        if len(self.chat_history) > 20:
            self.chat_history = self.chat_history[-20:]

        return answer
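
For completeness, search_chunks is a plain cosine-similarity lookup over the in-memory embeddings. A minimal sketch, reusing the illustrative self.chunks/self.embeddings storage and embedding model from the ingestion sketch above:

    def search_chunks(self, queries, top_k=5):
        """Return the top_k chunks most similar to the queries (cosine similarity)."""
        scores = np.zeros(len(self.chunks))  # numpy imported at module level as np
        for query in queries:
            q = np.array(utils.client.embeddings.create(
                model=utils.EMBED_MODEL, input=query
            ).data[0].embedding)
            m = np.array(self.embeddings)
            # Accumulate cosine similarity of each query against every stored chunk
            scores += (m @ q) / (np.linalg.norm(m, axis=1) * np.linalg.norm(q))
        best = np.argsort(scores)[::-1][:top_k]
        return [self.chunks[i] for i in best]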