> This course and its contents were distributed by Google LLC under the Apache License, Version 2.0, and have been reproduced in part, with modifications, below.
The course is now available as a [Learn Guide](https://www.kaggle.com/learn-guide/5-day-genai).
- Intro Unit: “Foundational Large Language Models & Text Generation”
- Listen to the summary [podcast episode](https://youtu.be/mQDlCZZsOyo?si=_TC7xZS4PE0ceaL5) for this unit (created by [NotebookLM](https://notebooklm.google.com/?original_referer=https://www.google.com%23&pli=1)).
- Read the [[barektain_2024|Foundational Large Language Models & Text Generation]] paper
- Unit 1: “Prompt Engineering”
- Listen to the summary [podcast episode](https://youtu.be/F_hJ2Ey4BNc?si=oPMAbqGe1RPl0ZKN) for this unit.
- Read the [[sherman_2024|Prompt Engineering]] whitepaper
- Complete [this code lab](https://www.kaggle.com/code/erikanderson1/day-1-prompting/edit) on Kaggle.
- Watch the [Day 1 Livestream](https://www.youtube.com/watch?v=kpRyiJUUFxY).
- Unit 2: “Embeddings and Vector Stores/Databases”
- Listen to the summary [podcast episode](https://www.youtube.com/watch?v=1CC39K76Nqs) for this unit.
- Read the [[nawalgaria_2024|Embeddings & Vector Stores]] whitepaper.
- Complete these code labs on Kaggle:
1. [Build](https://www.kaggle.com/code/erikanderson1/day-2-document-q-a-with-rag/edit) a RAG question-answering system over custom documents
2. [Explore](https://www.kaggle.com/code/erikanderson1/day-2-embeddings-and-similarity-scores/edit) text similarity with embeddings
3. [Build](https://www.kaggle.com/code/erikanderson1/day-2-classifying-embeddings-with-keras/edit) a neural classification network with Keras using embeddings
- Watch the [Day 2 Livestream](https://www.youtube.com/watch?v=86GZC56rQCc)
- Unit 3: “Generative AI Agents”
- Listen to the summary [podcast episode](https://notifications.googleapis.com/email/redirect?t=AFG8qyWNs5SkJRJgl1KX7BEohHZq0PQD3dus7j-N5KtFjbPL1fmgSYeAMvUjPoO5U084nlngmcvPhX3wW_mqDVTr-YRD8TVgD6UCP8UJK1jv9BvxJIq8m-dlrRuPSoFeByfXL0Hd3NBGWougnudCkd_HQQqH_lgTat7m673VDPBIKdUvQ8uTM6CJT80bdgjcoOKLG9rOgybakPy8X9F60rD0YNCfIsrEM24E5Qg8CNzpppcIcnZU0u26xQD5M8cw&r=eJzLKCkpKLbS16_MLy0p1UtK1fcwSY9KMXFyLnIJBACSrQmX&s=ALHZ2r6SF1_NUSLUUj8Gvi2G-zBI) for this unit
- Read the [[huang_2024|Agents (Artificial Intelligence)]] whitepaper.
- Complete these code labs on Kaggle:
1. [Talk](https://www.kaggle.com/code/erikanderson1/day-3-function-calling-with-the-gemini-api/edit) to a database with function calling
2. [Build](https://www.kaggle.com/code/erikanderson1/day-3-building-an-agent-with-langgraph/edit) an agentic ordering system in LangGraph
- Watch the [Day 3 Livestream](https://www.youtube.com/watch?v=HQUtMWoTAD4&list=PLqFaTIg4myu-b1PlxitQdY0UYIbys-2es&index=5)
- Unit 4: "Domain-Specific LLMs"
- Listen to the summary [podcast episode](https://notifications.googleapis.com/email/redirect?t=AFG8qyWyqEXEKJnqhgo0lfAcFd1nRc2UI5R7EBmJOyK4D74GzqDtGe3G_TjRCNNNbEVeCFQNmXUTomrrm_h1TkVk62Yabme0-MxkZm6Dw71jvkJ0Eg0Ik4sPmyyWaGK_BIjHG_6DpC8_-xQjgSydKnbTgc1aPpPYtTPgeZy9lAzW2MzGlpO7g9bipTdYfb8TzfjcrT95V2ywsHhq5S0CfR04mYbSWS-ONF80uTrzxyPKOw8bCXUcNWKMZpcQKLDV&r=eJzLKCkpKLbS16_MLy0p1UtK1U8yTDSJ8g-0iEjxBACSrAmV&s=ALHZ2r5bM79UY3aFs5UnKpaZj_EK) for this unit
- Read the [[semturs_2024|Solving Domain-Specific Problems Using LLMs]] whitepaper
- Complete these code labs on Kaggle:
1. [Use](https://www.kaggle.com/code/erikanderson1/day-4-google-search-grounding/edit) Google Search data in generation (requires paid license to run API)
2. [Tune](https://www.kaggle.com/code/erikanderson1/day-4-fine-tuning-a-custom-model/edit) a Gemini model for a custom task
- Watch the [Day 4 Livestream](https://www.youtube.com/watch?v=odvuLMJWUSU&list=PLqFaTIg4myu-b1PlxitQdY0UYIbys-2es&index=4)
- Unit 5: “MLOps for Generative AI”
- Listen to the summary [podcast episode](https://notifications.googleapis.com/email/redirect?t=AFG8qyWQR3kSCVDBRMUG9Sd2QbVARqKDojaIgBGqAsYwPzxxUvl7eucEwpk-qYbg2OjcLrBWi00sLxXkDP06nH5v6wUD3xWvVnhn2NXJTdhMajqYUWLS15EtoWz43-Hbihw2ZTiWobOIUYwiNG6ZuqxyH5eW7VnWRzY0R25wWAtcpgrOb-og7-egJi8Fjj8KLzaG0GemvX5DtXUeLkMAOLm3QV1DSMrmcRBzZZF57AyKTh0v0ftPvHSk4C8fn2uh&r=eJzLKCkpKLbS16_MLy0p1UtK1c-2DDbzzMgMDc0yAQCT6gnF&s=ALHZ2r6xTRdW8qKZ0tLl8jPjMp8Z) for this unit
- Read the [[nawalgaria_2024a|Operationalizing Generative AI on Vertex AI using MLOps]] whitepaper.
- No code lab for today! We will do a code walkthrough and live demo of [goo.gle/e2e-gen-ai-app-starter-pack](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/sample-apps/e2e-gen-ai-app-starter-pack), a resource created for making MLOps for Gen AI easier and accelerating the path to production. Please go through the repository in advance.
- Watch the [Day 5 Livestream](https://www.youtube.com/watch?v=uCFW0i9xrBc&list=PLqFaTIg4myu-b1PlxitQdY0UYIbys-2es&index=5&t=2s)
## Notes
LLMs can be adapted through fine-tuning, which requires significantly fewer resources than training from scratch. They can further be nudged towards desired performance by prompt engineering.
Steven Johnson: lead for NotebookLM. What would it be like to create a note-taking system that has an LLM at its center?
- Calls RAG "Source Grounding"
Wes Dyer: AIDA Engineering Lead, developing Data Science Agents
Check out the features and quota differences here: https://ai.google.dev/pricing
MLOps includes:
1. Discovery
2. Development & Experimentation
3. Data Engineering
4. Deployment
5. Continuous monitoring
6. Continuous Improvement
7. (Optional) Continuous training
8. Governance
Other terms covered across the units:
- Zero-shot learning
- One-shot learning
- Few-shot learning
- Parameter-Efficient Fine-Tuning (PEFT)
- Model chaining
- Supervised fine-tuning
- Chain-of-Thought prompting
- Tree prompting
- Retrieval Augmented Generation (RAG)
- Source grounding
- Foundation models
- Model pre-training
- Model evaluation
- Alignment
- Prompt engineering
- Latency & throughput
- Reinforcement Learning from Human Feedback (RLHF)
- Model quantization: reducing the precision of model weights and activations, for example from 32-bit floats to 8-bit integers
- Synthetic data
- Vector databases
- Embeddings
- Temperature
- Top-K
- Top-P
### Inspiration
For some inspiration, you might enjoy exploring some apps that have been built using the Gemini family of models. Here are a few that we like, and we think you will too.
- [TextFX](https://textfx.withgoogle.com/) is a suite of AI-powered tools for rappers, made in collaboration with Lupe Fiasco.
- [SQL Talk](https://sql-talk-r5gdynozbq-uc.a.run.app/) shows how you can talk directly to a database using the Gemini API.
- [NotebookLM](https://notebooklm.google/) uses Gemini models to build your own personal AI research assistant.
### Setup
All of the exercises in this notebook will use the [Gemini API](https://ai.google.dev/gemini-api/) by way of the [Python SDK](https://pypi.org/project/google-generativeai/).
```python
#!pip install -U -q "google-generativeai>=0.8.3"
import google.generativeai as genai
from IPython.display import HTML, Markdown, display
# Configure API (use the right approach, here we use Kaggle)
from kaggle_secrets import UserSecretsClient
GOOGLE_API_KEY = UserSecretsClient().get_secret("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)
flash = genai.GenerativeModel('gemini-1.5-flash')
response = flash.generate_content("Explain AI to me like I'm a kid.")
print(response.text)
```
Find available models.
```python
for model in genai.list_models():
    print(model.name)
```
```stdout
models/chat-bison-001
models/text-bison-001
models/embedding-gecko-001
models/gemini-1.0-pro-latest
models/gemini-1.0-pro
models/gemini-pro
models/gemini-1.0-pro-001
models/gemini-1.0-pro-vision-latest
models/gemini-pro-vision
models/gemini-1.5-pro-latest
models/gemini-1.5-pro-001
models/gemini-1.5-pro-002
models/gemini-1.5-pro
models/gemini-1.5-pro-exp-0801
models/gemini-1.5-pro-exp-0827
models/gemini-1.5-flash-latest
models/gemini-1.5-flash-001
models/gemini-1.5-flash-001-tuning
models/gemini-1.5-flash
models/gemini-1.5-flash-exp-0827
models/gemini-1.5-flash-002
models/gemini-1.5-flash-8b
models/gemini-1.5-flash-8b-001
models/gemini-1.5-flash-8b-latest
models/gemini-1.5-flash-8b-exp-0827
models/gemini-1.5-flash-8b-exp-0924
models/embedding-001
models/text-embedding-004
models/aqa
```
Pick a model and print the model information.
```python
for model in genai.list_models():
    if model.name == 'models/gemini-1.5-flash':
        print(model)
        break
```
```stdout
Model(name='models/gemini-1.5-flash',
base_model_id='',
version='001',
display_name='Gemini 1.5 Flash',
description='Fast and versatile multimodal model for scaling across diverse tasks',
input_token_limit=1000000,
output_token_limit=8192,
supported_generation_methods=['generateContent', 'countTokens'],
temperature=1.0,
max_temperature=2.0,
top_p=0.95,
top_k=40)
```
### Tune hyperparameters
When generating text with an LLM, the output length affects cost and performance. Generating more tokens increases computation, leading to higher energy consumption, latency, and cost.
To stop the model from generating tokens past a limit, you can specify the `max_output_tokens` parameter when using the Gemini API. This parameter does not make the output more stylistically or textually succinct; the model simply stops generating once the specified number of tokens is reached, so the response may be cut off mid-thought. Prompt engineering may be required to get a complete answer within your chosen limit.
```python
short_model = genai.GenerativeModel(
'gemini-1.5-flash',
generation_config=genai.GenerationConfig(max_output_tokens=200))
```
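As a quick check, here is a minimal usage sketch (assuming the `flash` and `short_model` objects defined above; the prompt itself is made up):
```python
# Hypothetical prompt to compare a length-limited response against the default model.
prompt = "Write a 1000 word essay on the importance of olives in modern society."

limited = short_model.generate_content(prompt)
print(limited.text)  # generation stops once roughly 200 output tokens are reached

full = flash.generate_content(prompt)
print(len(full.text), "characters without the limit")
```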
Temperature controls the degree of randomness in token selection: higher values produce more varied output, while a temperature of 0 corresponds to greedy (deterministic) decoding. Temperature doesn't provide any guarantees about the output, but it can be used to "nudge" it towards more or less creative responses.
```python
from google.api_core import retry
high_temp_model = genai.GenerativeModel(
'gemini-1.5-flash',
generation_config=genai.GenerationConfig(temperature=2.0))
# When running lots of queries, it's a good practice to use a retry policy so your code
# automatically retries when hitting Resource Exhausted (quota limit) errors.
retry_policy = {
"retry": retry.Retry(predicate=retry.if_transient_error, initial=10, multiplier=1.5, timeout=300)
}
for _ in range(5):
    response = high_temp_model.generate_content('Pick a random colour... (respond in a single word)',
                                                request_options=retry_policy)
    if response.parts:
        print(response.text, '-' * 25)
```
Like temperature, top-K and top-P parameters are also used to control the diversity of the model's output.
Top-K is a positive integer that defines the number of most probable tokens from which to select the output token. A top-K of 1 selects a single token, performing greedy decoding.
Top-P defines the probability threshold that, once cumulatively exceeded, tokens stop being selected as candidates. A top-P of 0 is typically equivalent to greedy decoding, and a top-P of 1 typically selects every token in the model's vocabulary.
When both are supplied, the Gemini API filters the top-K tokens first, then applies top-P, and finally samples from the remaining candidate tokens using the supplied temperature.
```python
model = genai.GenerativeModel(
'gemini-1.5-flash-001',
generation_config=genai.GenerationConfig(
# These are the default values for gemini-1.5-flash-001.
temperature=1.0,
top_k=64,
top_p=0.95,
))
```
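To make that order of operations concrete, here is a small illustrative sketch of the filtering pipeline on a toy next-token distribution. The vocabulary and probabilities are invented, and this is a simplified approximation of the decoding procedure, not the API's exact implementation:
```python
import numpy as np

# Toy next-token distribution (invented for illustration).
vocab = ["the", "a", "fox", "dog", "pizza"]
probs = np.array([0.40, 0.25, 0.20, 0.10, 0.05])

TOP_K, TOP_P, TEMPERATURE = 3, 0.90, 1.0

# 1) Top-K: keep only the K most probable tokens.
candidates = list(np.argsort(probs)[::-1][:TOP_K])

# 2) Top-P: walk the candidates in order of probability and stop once the
#    cumulative probability reaches the threshold.
kept, cumulative = [], 0.0
for i in candidates:
    kept.append(i)
    cumulative += probs[i]
    if cumulative >= TOP_P:
        break

# 3) Temperature: rescale the surviving (log-)probabilities and sample.
logits = np.log(probs[kept]) / TEMPERATURE
weights = np.exp(logits) / np.exp(logits).sum()
print("Candidates:", [vocab[i] for i in kept])
print("Sampled token:", np.random.choice([vocab[i] for i in kept], p=weights))
```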
### Prompt engineering
The models are trained to generate text, and can sometimes produce more text than you may wish for. The Gemini API has an [Enum mode](https://github.com/google-gemini/cookbook/blob/main/quickstarts/Enum.ipynb) feature that allows you to constrain the output to a fixed set of values.
```python
import enum

class Sentiment(enum.Enum):
    POSITIVE = "positive"
    NEUTRAL = "neutral"
    NEGATIVE = "negative"

model = genai.GenerativeModel(
    'gemini-1.5-flash-001',
    generation_config=genai.GenerationConfig(
        response_mime_type="text/x.enum",
        response_schema=Sentiment))
```
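A minimal usage sketch, assuming the enum-constrained `model` just defined and the `retry_policy` from earlier (the review text is made up):
```python
# The output is constrained to one of the Sentiment values defined above.
review = "The pizza arrived cold and the crust was soggy. Never again."
response = model.generate_content(
    f"Classify the sentiment of this review:\n{review}",
    request_options=retry_policy)
print(response.text)  # e.g. "negative"
```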
To provide control over the schema, and to ensure that you only receive JSON (with no other text or markdown), you can use the Gemini API's [JSON mode](https://github.com/google-gemini/cookbook/blob/main/quickstarts/JSON_mode.ipynb). This forces the model to constrain decoding, such that token selection is guided by the supplied schema.
```python
import typing_extensions as typing

class PizzaOrder(typing.TypedDict):
    size: str
    ingredients: list[str]
    type: str

model = genai.GenerativeModel(
    'gemini-1.5-flash-latest',
    generation_config=genai.GenerationConfig(
        temperature=0.1,
        response_mime_type="application/json",
        response_schema=PizzaOrder))
```
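And a minimal usage sketch for JSON mode, assuming the `model` just defined (the order text is made up):
```python
import json

# The response text is constrained to valid JSON matching the PizzaOrder schema.
response = model.generate_content(
    "Can I have a large dessert pizza with apple and chocolate?",
    request_options=retry_policy)

order = json.loads(response.text)
print(order)
```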
With the code-execution tool enabled, the model can write and run code as part of answering a prompt, and you can inspect each part of the response to follow the steps it took.
```python
model = genai.GenerativeModel(
'gemini-1.5-flash-latest',
tools='code_execution',)
code_exec_prompt = """
Calculate the sum of the first 14 prime numbers. Only consider the odd primes, and make sure you count them all.
"""
response = model.generate_content(code_exec_prompt, request_options=retry_policy)
for part in response.candidates[0].content.parts:
    print(part)
    print("-----")
```
The Gemini family of models can explain code to you too.
```python
file_contents = !curl https://raw.githubusercontent.com/magicmonty/bash-git-prompt/refs/heads/master/gitprompt.sh
explain_prompt = f"""
Please explain what this file does at a very high level. What is it, and why would I use it?

{file_contents}
"""
model = genai.GenerativeModel('gemini-1.5-flash-latest')
response = model.generate_content(explain_prompt, request_options=retry_policy)
Markdown(response.text)
```
### Reason and Act Prompt
ReAct (Reason and Act) is a prompt engineering approach that interleaves thoughts, actions and observations to help an LLM work through a reasoning problem step by step. See the [Searching Wikipedia with ReAct](https://github.com/google-gemini/cookbook/blob/main/examples/Search_Wikipedia_using_ReAct.ipynb) cookbook example.
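The cookbook notebook drives the full loop against the live Wikipedia API; the sketch below shows only the prompting pattern, with invented instructions, an invented question, and a stop sequence so the model pauses instead of inventing an observation (it reuses the `retry_policy` defined earlier):
```python
react_instructions = """Solve the question by interleaving Thought, Action and
Observation steps. Available actions:
  <search>topic</search>  - look the topic up and return a summary
  <finish>answer</finish> - return the final answer and stop
"""

question = "In which year was the tallest building in Paris completed?"

react_model = genai.GenerativeModel(
    'gemini-1.5-flash-latest',
    generation_config=genai.GenerationConfig(stop_sequences=["\nObservation"]))

# The model stops before writing an Observation; the calling code would perform the
# requested search, append the real observation to the prompt, and generate again.
response = react_model.generate_content(react_instructions + "\nQuestion: " + question,
                                        request_options=retry_policy)
print(response.text)
```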
## Retrieval Augmented Generation (RAG)
Two big limitations of LLMs are 1) that they only "know" the information that they were trained on, and 2) that they have limited input context windows. A way to address both of these limitations is to use a technique called Retrieval Augmented Generation, or RAG. A RAG system has three stages:
1. Indexing
2. Retrieval
3. Generation
First, install ChromaDB and the Gemini API Python SDK (see above for steps; in Kaggle, make sure to check the box next to an existing secret to apply it to a new notebook).
Use the following code to discover embeddings models. `text-embedding-004` is the most recent.
```python
for m in genai.list_models():
    if "embedContent" in m.supported_generation_methods:
        print(m.name)
```
```stdout
models/embedding-001
models/text-embedding-004
```
Create a small set of documents for an embedding database.
```python
DOCUMENT1 = "Operating the Climate Control System Your Googlecar has a climate control system that allows you to adjust the temperature and airflow in the car. To operate the climate control system, use the buttons and knobs located on the center console. Temperature: The temperature knob controls the temperature inside the car. Turn the knob clockwise to increase the temperature or counterclockwise to decrease the temperature. Airflow: The airflow knob controls the amount of airflow inside the car. Turn the knob clockwise to increase the airflow or counterclockwise to decrease the airflow. Fan speed: The fan speed knob controls the speed of the fan. Turn the knob clockwise to increase the fan speed or counterclockwise to decrease the fan speed. Mode: The mode button allows you to select the desired mode. The available modes are: Auto: The car will automatically adjust the temperature and airflow to maintain a comfortable level. Cool: The car will blow cool air into the car. Heat: The car will blow warm air into the car. Defrost: The car will blow warm air onto the windshield to defrost it."
DOCUMENT2 = 'Your Googlecar has a large touchscreen display that provides access to a variety of features, including navigation, entertainment, and climate control. To use the touchscreen display, simply touch the desired icon. For example, you can touch the "Navigation" icon to get directions to your destination or touch the "Music" icon to play your favorite songs.'
DOCUMENT3 = "Shifting Gears Your Googlecar has an automatic transmission. To shift gears, simply move the shift lever to the desired position. Park: This position is used when you are parked. The wheels are locked and the car cannot move. Reverse: This position is used to back up. Neutral: This position is used when you are stopped at a light or in traffic. The car is not in gear and will not move unless you press the gas pedal. Drive: This position is used to drive forward. Low: This position is used for driving in snow or other slippery conditions."
documents = [DOCUMENT1, DOCUMENT2, DOCUMENT3]
```
Create a [custom function](https://docs.trychroma.com/guides/embeddings#custom-embedding-functions) to generate embeddings with the Gemini API. In this task, you are implementing a retrieval system, so the `task_type` for generating the _document_ embeddings is `retrieval_document`. Later, you will use `retrieval_query` for the _query_ embeddings. Check out the [API reference](https://ai.google.dev/api/embeddings#v1beta.TaskType) for the full list of supported tasks.
```python
from chromadb import Documents, EmbeddingFunction, Embeddings
from google.api_core import retry
class GeminiEmbeddingFunction(EmbeddingFunction):
    # Specify whether to generate embeddings for documents, or queries.
    document_mode = True

    def __call__(self, input: Documents) -> Embeddings:
        if self.document_mode:
            embedding_task = "retrieval_document"
        else:
            embedding_task = "retrieval_query"

        retry_policy = {"retry": retry.Retry(predicate=retry.if_transient_error)}

        response = genai.embed_content(
            model="models/text-embedding-004",
            content=input,
            task_type=embedding_task,
            request_options=retry_policy,
        )
        return response["embedding"]
```
Now create a [Chroma database client](https://docs.trychroma.com/getting-started) that uses the `GeminiEmbeddingFunction` and populate the database with the documents you defined above.
```python
import chromadb
DB_NAME = "googlecardb"
embed_fn = GeminiEmbeddingFunction()
embed_fn.document_mode = True
chroma_client = chromadb.Client()
db = chroma_client.get_or_create_collection(name=DB_NAME, embedding_function=embed_fn)
db.add(documents=documents, ids=[str(i) for i in range(len(documents))])
db.count()
# You can peek at the data too.
db.peek(1)
```
```json
{'ids': ['0'],
'embeddings': array([[ 1.89996641e-02, ...]]),
'documents': ...,
'uris': None,
'data': None,
'metadatas': [None],
'included': [<IncludeEnum.embeddings: 'embeddings'>,
<IncludeEnum.documents: 'documents'>,
<IncludeEnum.metadatas: 'metadatas'>]}
```
To search the Chroma database, call the `query` method. Note that you also switch to the `retrieval_query` mode of embedding generation.
```python
# Switch to query mode when generating embeddings.
embed_fn.document_mode = False
# Search the Chroma DB using the specified query.
query = "How do you use the touchscreen to play music?"
result = db.query(query_texts=[query], n_results=1)
[[passage]] = result["documents"]
Markdown(passage)
```
```markdown
Your Googlecar has a large touchscreen display that provides access to a variety of features, including navigation, entertainment, and climate control. To use the touchscreen display, simply touch the desired icon. For example, you can touch the "Navigation" icon to get directions to your destination or touch the "Music" icon to play your favorite songs.
```
Now that you have found a relevant passage from the set of documents (the _retrieval_ step), you can assemble a generation prompt so that the Gemini API _generates_ a final answer. Note that in this example only a single passage was retrieved. In practice, especially when your underlying data is large, you will want to retrieve more than one result and let the Gemini model determine which passages are relevant to the question. For this reason it's OK if some retrieved passages are not directly related to the question: the generation step should ignore them.
```python
passage_oneline = passage.replace("\n", " ")
query_oneline = query.replace("\n", " ")
# This prompt is where you can specify any guidance on tone, or what topics the model should stick to, or avoid.
prompt = f"""You are a helpful and informative bot that answers questions using text from the reference passage included below.
Be sure to respond in a complete sentence, being comprehensive, including all relevant background information.
However, you are talking to a non-technical audience, so be sure to break down complicated concepts and
strike a friendly and conversational tone. If the passage is irrelevant to the answer, you may ignore it.
QUESTION: {query_oneline}
PASSAGE: {passage_oneline}
"""
model = genai.GenerativeModel("gemini-1.5-flash-latest")
answer = model.generate_content(prompt)
Markdown(answer.text)
```
```markdown
To play music, simply touch the "Music" icon on the touchscreen display.
```
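As noted above, with a larger corpus you would usually retrieve several passages and let the model decide which are relevant. A minimal sketch of that variant, reusing the `db`, `query`, `query_oneline` and `model` objects already defined (the three-document toy database only has three passages to return):
```python
# Retrieve several candidate passages instead of just one.
result = db.query(query_texts=[query], n_results=3)
[all_passages] = result["documents"]

# Include every retrieved passage; the prompt tells the model to ignore irrelevant ones.
prompt = f"""Answer the QUESTION using the PASSAGES below. If a passage is irrelevant, ignore it.

QUESTION: {query_oneline}
"""
for p in all_passages:
    passage_text = p.replace("\n", " ")
    prompt += f"PASSAGE: {passage_text}\n"

answer = model.generate_content(prompt)
Markdown(answer.text)
```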
To learn more about using embeddings in the Gemini API, check out the [Intro to embeddings](https://ai.google.dev/gemini-api/docs/embeddings) or to learn more fundamentals, study the [embeddings chapter](https://developers.google.com/machine-learning/crash-course/embeddings) of the Machine Learning Crash Course.
For a hosted RAG system, check out the [Semantic Retrieval service](https://ai.google.dev/gemini-api/docs/semantic_retrieval) in the Gemini API. You can implement question-answering on your own documents in a single request, or host a database for even faster responses.
## Similarity scores
A similarity score of two embedding vectors can be obtained by calculating their inner product. If $u$ is the first embedding vector, and $v$ the second, this is $u^Tv$. As these embedding vectors are normalized to unit length, this is also the cosine similarity.
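A tiny sketch to verify that claim numerically, using two arbitrary texts embedded for semantic similarity:
```python
import numpy as np

emb = genai.embed_content(model='models/text-embedding-004',
                          content=['kittens meow', 'cats purr'],
                          task_type='semantic_similarity')['embedding']
u, v = np.array(emb[0]), np.array(emb[1])

print(np.linalg.norm(u), np.linalg.norm(v))  # ~1.0 if the embeddings are unit length, as noted above
print(u @ v)                                 # inner product
print(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))  # cosine similarity: the same value
```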
This score can be computed across all embeddings through the matrix self-multiplication: `df @ df.T`.
Note that the range from 0.0 (completely dissimilar) to 1.0 (completely similar) is depicted in the heatmap from dark (0.0) to light (1.0).
In the example below, we use the embeddings to calculate similarity scores, so the `task_type` for these embeddings is `semantic_similarity`. Check out the [API reference](https://ai.google.dev/api/embeddings#v1beta.TaskType) for the full list of tasks.
```python
texts = [
'The quick brown fox jumps over the lazy dog.',
'The quick rbown fox jumps over the lazy dog.',
'teh fast fox jumps over the slow woofer.',
'a quick brown fox jmps over lazy dog.',
'brown fox jumping over dog',
'fox > dog',
# Alternative pangram for comparison:
'The five boxing wizards jump quickly.',
# Unrelated text, also for comparison:
'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Vivamus et hendrerit massa. Sed pulvinar, nisi a lobortis sagittis, neque risus gravida dolor, in porta dui odio vel purus.',
]
response = genai.embed_content(model='models/text-embedding-004',
content=texts,
task_type='semantic_similarity')
import pandas as pd
import seaborn as sns
# Helper function to truncate text for visualization
def truncate(t: str, limit: int = 50) -> str:
    """Truncate labels to fit on the chart."""
    if len(t) > limit:
        return t[:limit - 3] + '...'
    else:
        return t
truncated_texts = [truncate(t) for t in texts]
# Set up the embeddings in a dataframe.
df = pd.DataFrame(response['embedding'], index=truncated_texts)
# Perform the similarity calculation
sim = df @ df.T
# Draw!
sns.heatmap(sim, vmin=0, vmax=1);
sim['The quick brown fox jumps over the lazy dog.'].sort_values(ascending=False)
```
```stdout
The quick brown fox jumps over the lazy dog. 0.999999
The quick rbown fox jumps over the lazy dog. 0.975623
a quick brown fox jmps over lazy dog. 0.939730
brown fox jumping over dog 0.894507
teh fast fox jumps over the slow woofer. 0.842152
fox > dog 0.776455
The five boxing wizards jump quickly. 0.635346
Lorem ipsum dolor sit amet, consectetur adipisc... 0.472174
Name: The quick brown fox jumps over the lazy dog., dtype: float64
```
**Further reading**
- Explore [search re-ranking using embeddings](https://github.com/google-gemini/cookbook/blob/main/examples/Search_reranking_using_embeddings.ipynb) with the Wikipedia API
- Perform [anomaly detection using embeddings](https://github.com/google-gemini/cookbook/blob/main/examples/Anomaly_detection_with_embeddings.ipynb)
## Classifying embeddings with Keras and the Gemini API
In this notebook, you'll learn to use the embeddings produced by the Gemini API to train a model that can classify newsgroup posts into the categories (the newsgroup itself) from the post contents.
This technique uses the Gemini API's embeddings as input, avoiding the need to train on text input directly, and as a result it is able to perform quite well using relatively few examples compared to training a text model from scratch.
```python
from sklearn.datasets import fetch_20newsgroups
newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")
```
Start by preprocessing the data for this tutorial in a Pandas dataframe. To remove any sensitive information like names and email addresses, you will take only the subject and body of each message. This is an optional step that transforms the input data into more generic text, rather than email posts, so that it will work in other contexts.
```python
import email
import re
import pandas as pd
def preprocess_newsgroup_row(data):
    # Extract only the subject and body
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    # Strip any remaining email addresses
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Truncate each entry to 5,000 characters
    text = text[:5000]
    return text


def preprocess_newsgroup_data(newsgroup_dataset):
    # Put data points into dataframe
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )
    # Clean up the text
    df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
    # Match label to target name index
    df["Class Name"] = df["Label"].map(lambda l: newsgroup_dataset.target_names[l])
    return df


# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)
```
Next, you will subsample the data, taking 100 data points per class from the training dataset and dropping all but a few categories so the tutorial runs quickly. Keep the science categories to compare.
```python
def sample_data(df, num_samples, classes_to_keep):
    # Sample rows, selecting num_samples of each Label.
    df = (
        df.groupby("Label")[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )

    df = df[df["Class Name"].str.contains(classes_to_keep)]

    # We have fewer categories now, so re-calibrate the label encoding.
    df["Class Name"] = df["Class Name"].astype("category")
    df["Encoded Label"] = df["Class Name"].cat.codes

    return df
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = "sci" # Class name should contain 'sci' to keep science categories
df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
```
In this section, you will generate embeddings for each piece of text using the Gemini API embeddings endpoint. To learn more about embeddings, visit the [embeddings guide](https://ai.google.dev/docs/embeddings_guide).
The `text-embedding-004` model supports a task type parameter that generates embeddings tailored for the specific task.
Task Type | Description
--- | ---
RETRIEVAL_QUERY | Specifies the given text is a query in a search/retrieval setting.
RETRIEVAL_DOCUMENT | Specifies the given text is a document in a search/retrieval setting.
SEMANTIC_SIMILARITY | Specifies the given text will be used for Semantic Textual Similarity (STS).
CLASSIFICATION | Specifies that the embeddings will be used for classification.
CLUSTERING | Specifies that the embeddings will be used for clustering.
FACT_VERIFICATION | Specifies that the given text will be used for fact verification.
For this example you will be performing classification. Check out the [API reference](https://ai.google.dev/api/embeddings#v1beta.TaskType) for more on each supported task.
> [!Warning]
> This code is optimized for clarity, and is not particularly fast. It is left as an exercise for the reader to implement [batch](https://ai.google.dev/api/embeddings#method:-models.batchembedcontents) or parallel/asynchronous embedding generation. Running this step will take some time.
```python
from google.api_core import retry
from tqdm.rich import tqdm
tqdm.pandas()
@retry.Retry(timeout=300.0)
def embed_fn(text: str) -> list[float]:
    # You will be performing classification, so set task_type accordingly.
    response = genai.embed_content(
        model="models/text-embedding-004", content=text, task_type="classification"
    )
    return response["embedding"]


def create_embeddings(df):
    df["Embeddings"] = df["Text"].progress_apply(embed_fn)
    return df
df_train = create_embeddings(df_train)
df_test = create_embeddings(df_test)
```
Here you will define a simple model that accepts the raw embedding data as input, has one hidden layer, and an output layer specifying the class probabilities. The prediction will correspond to the probability of a piece of text being a particular class of news.
When you run the model, Keras will take care of details like shuffling the data points, calculating metrics and other ML boilerplate.
```python
import keras
from keras import layers
def build_classification_model(input_size: int, num_classes: int) -> keras.Model:
    return keras.Sequential(
        [
            layers.Input([input_size], name="embedding_inputs"),
            layers.Dense(input_size, activation="relu", name="hidden"),
            layers.Dense(num_classes, activation="softmax", name="output_probs"),
        ]
    )
# Derive the embedding size from observing the data. The embedding size can also be specified
# with the `output_dimensionality` parameter to `embed_content` if you need to reduce it.
embedding_size = len(df_train["Embeddings"].iloc[0])
classifier = build_classification_model(
embedding_size, len(df_train["Class Name"].unique())
)
classifier.summary()
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(learning_rate=0.001),
metrics=["accuracy"],
)
```
Finally, you can train your model. This code uses early stopping to exit the training loop once the loss value stabilises, so the number of epoch loops executed may differ from the specified value.
```python
import numpy as np
NUM_EPOCHS = 20
BATCH_SIZE = 32
# Split the x and y components of the train and validation subsets.
y_train = df_train["Encoded Label"]
x_train = np.stack(df_train["Embeddings"])
y_val = df_test["Encoded Label"]
x_val = np.stack(df_test["Embeddings"])
# Specify that it's OK to stop early if accuracy stabilises.
early_stop = keras.callbacks.EarlyStopping(monitor="accuracy", patience=3)
# Train the model for the desired number of epochs.
history = classifier.fit(
x=x_train,
y=y_train,
validation_data=(x_val, y_val),
callbacks=[early_stop],
batch_size=BATCH_SIZE,
epochs=NUM_EPOCHS,
)
```
Use Keras [`Model.evaluate`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate) to calculate the loss and accuracy on the test dataset.
```python
classifier.evaluate(x=x_val, y=y_val, return_dict=True)
```
To learn more about training models with Keras, including how to visualize the model training metrics, read [Training & evaluation with built-in methods](https://www.tensorflow.org/guide/keras/training_with_built_in_methods).
Now that you have a trained model with good evaluation metrics, you can try to make a prediction with new, hand-written data. Use the provided example or try your own data to see how the model performs.
```python
# This example avoids any space-specific terminology to see if the model avoids
# biases towards specific jargon.
new_text = """
First-timer looking to get out of here.
Hi, I'm writing about my interest in travelling to the outer limits!
What kind of craft can I buy? What is easiest to access from this 3rd rock?
Let me know how to do that please.
"""
embedded = embed_fn(new_text)
# Remember that the model takes embeddings as input, and the input must be batched,
# so here they are passed as a list to provide a batch of 1.
inp = np.array([embedded])
[result] = classifier.predict(inp)
for idx, category in enumerate(df_test["Class Name"].cat.categories):
    print(f"{category}: {result[idx] * 100:0.2f}%")
```
```stdout
1/1 ------------------- 0s 52ms/step
sci.crypt: 0.33%
sci.electronics: 0.74%
sci.med: 0.37%
sci.space: 98.56%
```