# Gradio RAG chatbot: indexes webpages and arXiv papers, answers with Mistral.
import gradio as gr
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import ArxivRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.runnables import RunnablePassthrough
from langchain_mistralai import ChatMistralAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,
    check_every_n_seconds=0.01,
    max_bucket_size=10,
)
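# What the settings above mean: requests_per_second=0.1 allows on average one
# LLM request every 10 seconds, availability is re-checked every 10 ms, and up
# to max_bucket_size=10 requests can burst through at once. These values appear
# chosen to stay within the Mistral API's rate limits (an assumption from the
# numbers themselves; the limits are not documented in this file).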
|
|
# Read the list of webpage URLs to index, one URL per line
with open("urls.txt") as urlsfile:
    urls = [url.strip() for url in urlsfile]
|
|
# Fetch each page and parse it into a LangChain Document
loader = WebBaseLoader(urls)
docs = loader.load()
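# Optional sanity check (illustrative addition, not part of the original
# pipeline): confirm every URL yielded a parsed document before indexing.
print(f"Loaded {len(docs)} documents from {len(urls)} URLs")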
|
|
# Read the list of arXiv identifiers to index, one ID per line
with open("arxiv.txt") as arxivfile:
    arxivs = [arxiv.strip() for arxiv in arxivfile]
|
|
retriever = ArxivRetriever(
    load_max_docs=2,
    get_full_documents=True,
)
|
|
# Retrieve each arXiv paper and keep only the top hit
for arxiv in arxivs:
    doc = retriever.invoke(arxiv)
    doc[0].metadata['Published'] = str(doc[0].metadata['Published'])
    docs.append(doc[0])
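# Why the str() cast above: ArxivRetriever returns the 'Published' field as a
# datetime.date, but Chroma only accepts scalar metadata values (str, int,
# float, bool), so the date must be stringified before indexing.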
|
|
def format_docs(docs):
    # Concatenate retrieved chunks into a single context string for the prompt
    return "\n\n".join(doc.page_content for doc in docs)
|
|
def RAG(llm, docs, embeddings):
    # Split documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Embed the chunks and index them in an in-memory Chroma vector store
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)

    # Expose the vector store as a retriever
    retriever = vectorstore.as_retriever()

    # Pull a standard RAG prompt template from the LangChain Hub
    prompt = hub.pull("rlm/rag-prompt")

    # Compose the chain: retrieve context, fill the prompt, call the LLM,
    # and parse the model output to a plain string
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain
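# Illustrative usage of the returned chain (kept as a comment so importing the
# module does not spend API calls; the question is hypothetical):
#   chain = RAG(llm, docs, embeddings)
#   answer = chain.invoke("What is SimBIG?")        # single-shot answer
#   for chunk in chain.stream("What is SimBIG?"):   # token-by-token streaming
#       print(chunk, end="", flush=True)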
|
|
# LLM: Mistral's hosted chat model, throttled by the rate limiter above
llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
|
|
# Embeddings: multi-qa-distilbert-cos-v1 is a plain sentence-transformers
# checkpoint, so the generic HuggingFaceEmbeddings wrapper fits here (the
# Instructor wrapper expects instruction-tuned hkunlp/instructor models)
embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
embeddings = HuggingFaceEmbeddings(model_name=embed_model)
|
|
# Build the RAG chain over the combined web and arXiv documents
rag_chain = RAG(llm, docs, embeddings)
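# Note: constructing the chain embeds every chunk up front, and the first run
# also downloads the sentence-transformers model, so expect some startup delay.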
|
|
def handle_prompt(message, history):
    try:
        # Stream the answer so the UI updates token by token
        out = ""
        for chunk in rag_chain.stream(message):
            out += chunk
            yield out
    except Exception:
        # Most failures here come from hitting the Mistral API rate limit
        raise gr.Error("Requests rate limit exceeded")
|
|
greetingsmessage = "Hi, I'm ChangBot, a chat bot here to assist you with any question related to Chang's research. I'm in pre-alpha stage, so please be patient."
example_questions = [
    "Tell me more about SimBIG",
    "How can you constrain neutrino mass with galaxies?",
    "What is the DESI BGS?",
    "What is SEDflow?",
    "What are normalizing flows?",
]
|
|
demo = gr.ChatInterface(
    handle_prompt,
    type="messages",
    title="ChangBot",
    examples=example_questions,
    theme=gr.themes.Soft(),
    description=greetingsmessage,
)
|
|
demo.launch()