Coverage for src/app/repositories/langChain_repository.py: 100%

48 statements  

coverage.py v7.7.0, created at 2025-04-03 00:42 +0200

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from langchain_core.messages import trim_messages
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory

from entities.query_entity import QueryEntity
from entities.document_context_entity import DocumentContextEntity
from entities.answer_entity import AnswerEntity
from entities.file_entity import FileEntity
from entities.file_chunk_entity import FileChunkEntity

class LangChainRepository:

    def __init__(self, model: ChatOpenAI):
        self.model = model
        self.user_memories = {}  # one ConversationBufferMemory per user id
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=2500, chunk_overlap=0)

    def get_user_memory(self, user_id: int):
        """Retrieve or create the conversation memory for a specific user."""
        if user_id not in self.user_memories:
            self.user_memories[user_id] = ConversationBufferMemory(return_messages=True)
        return self.user_memories[user_id]
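
    # Usage sketch: repeated calls with the same user_id return the same
    # ConversationBufferMemory instance, so each user's history stays isolated.
    #
    #     repo = LangChainRepository(ChatOpenAI())
    #     assert repo.get_user_memory(1) is repo.get_user_memory(1)
    #     assert repo.get_user_memory(1) is not repo.get_user_memory(2)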

    def generate_answer(self, query: QueryEntity, contexts: list[DocumentContextEntity], prompt_template: str) -> AnswerEntity:
        """
        Given a query and a list of document contexts, call the OpenAI LLM and get a detailed answer.

        Args:
            query (QueryEntity): The query entity.
            contexts (list[DocumentContextEntity]): A list of document context entities.
            prompt_template (str): The system message instructing the LLM how it should behave.

        Returns:
            AnswerEntity: An answer entity containing the answer given by the LLM.

        Raises:
            ValueError: If the query is empty.
            Exception: If an error occurs during answer generation.
        """
        if not query.get_query().strip():
            raise ValueError("Query cannot be empty")

        try:
            user_question = query.get_query()
            documents = [Document(page_content=context.get_content()) for context in contexts]

            # Get the user-specific memory
            memory = self.get_user_memory(query.get_user_id())
            history = memory.load_memory_variables({})["history"]

            # Trim the history to the most recent messages that fit in the token budget
            trimmed_history = trim_messages(
                history,
                max_tokens=2000,
                strategy="last",
                include_system=True,
                token_counter=self.model
            )

            history_message = f"Previous conversation history: {''.join(msg.content for msg in trimmed_history) if isinstance(trimmed_history, list) else str(trimmed_history)}"

            # Embed the trimmed history in the prompt as a system message
            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", prompt_template),
                    ("system", history_message),
                    ("user", "{user_question}"),
                    ("system", "{context}")
                ]
            )

            chain = create_stuff_documents_chain(
                llm=self.model,
                prompt=prompt
            )

            # The prompt only declares the {user_question} and {context} variables;
            # the history is already baked into the system message above.
            answer = chain.invoke({
                "user_question": user_question,
                "context": documents
            })

            # Store the interaction in the user-specific memory
            memory.save_context({"input": user_question}, {"output": answer})

            return AnswerEntity(answer)

        except Exception as e:
            raise Exception(f"Error while generating the answer from LangChain model for user {query.get_user_id()}: " + str(e))
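
    # Usage sketch, assuming QueryEntity(query, user_id) and DocumentContextEntity(content)
    # constructors matching the getters used above (they are not defined in this file):
    #
    #     query = QueryEntity("What is LangChain?", 1)
    #     contexts = [DocumentContextEntity("LangChain is a framework for building LLM apps.")]
    #     answer = repo.generate_answer(query, contexts, "Answer using only the provided context.")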

    def split_file(self, file: FileEntity) -> list[FileChunkEntity]:
        """
        Given a file entity, split the file into chunks of 2,500 characters.

        Args:
            file (FileEntity): The file entity to split.

        Returns:
            list[FileChunkEntity]: A list of file chunk entities containing the file chunks.

        Raises:
            Exception: If an error occurs while splitting the file.
        """
        try:
            file_content = file.get_file_content()
            if isinstance(file_content, bytes):
                file_content = file_content.decode('utf-8', errors='ignore')

            all_splits = self.text_splitter.split_text(file_content)

            file_chunks = [FileChunkEntity(split, file.get_metadata()) for split in all_splits]

            return file_chunks

        except Exception as e:
            raise Exception("Error while splitting the file: " + str(e))
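
# Minimal end-to-end sketch. The constructors and getters below, such as
# FileEntity(content, metadata), FileChunkEntity.get_content(), QueryEntity(query, user_id),
# and DocumentContextEntity(content), are assumptions about the entity classes (which are
# not defined in this file), and an OPENAI_API_KEY is assumed to be set in the environment.
if __name__ == "__main__":
    repo = LangChainRepository(ChatOpenAI(model="gpt-4o-mini"))  # assumed model name

    file = FileEntity("Some long document text ...", {"source": "example.txt"})  # assumed constructor
    chunks = repo.split_file(file)
    contexts = [DocumentContextEntity(chunk.get_content()) for chunk in chunks]  # assumed getter

    query = QueryEntity("What does the document say?", 1)  # assumed constructor
    answer = repo.generate_answer(query, contexts, "Answer using only the provided context.")
    print(answer)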