Spring AI

Spring AI

简介

类似于 LangChain,Spring 也提供了与大模型相关的库。目前主要支持文本对话和从文本生成图像,此外对向量数据库的支持也比较好。

使用方式

Ollama Chat

Spring Initializer 里可以引入如下内容:

  • Ollama
  • Spring Web
  • Spring Reactive Web

可以得到如下样例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
repositories {
    mavenCentral()
    // Milestone repository; needed to resolve Spring AI milestone artifacts.
    maven { url 'https://repo.spring.io/milestone' }
}

ext {
    // Code that uses ChatClient.Builder (org.springframework.ai.chat.client) and
    // EmbeddingModel needs 1.0.0-M1 or later; 0.8.1 predates both APIs.
    set('springAiVersion', "1.0.0-M1")
}

dependencies {
    implementation 'org.springframework.boot:spring-boot-starter-web'
    implementation 'org.springframework.boot:spring-boot-starter-webflux'
    implementation 'org.springframework.ai:spring-ai-transformers-spring-boot-starter'
    implementation 'org.springframework.ai:spring-ai-ollama-spring-boot-starter'
    compileOnly 'org.projectlombok:lombok'
    developmentOnly 'org.springframework.boot:spring-boot-devtools'
    annotationProcessor 'org.projectlombok:lombok'
    testImplementation 'org.springframework.boot:spring-boot-starter-test'
    testImplementation 'io.projectreactor:reactor-test'
}

dependencyManagement {
    imports {
        // The BOM pins the version of every org.springframework.ai artifact declared above.
        mavenBom "org.springframework.ai:spring-ai-bom:${springAiVersion}"
    }
}

之后编写如下接口即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import java.util.List;

/**
 * REST endpoints under {@code /ollama} that forward requests to the configured
 * {@link ChatClient} and {@link EmbeddingModel}.
 */
@RestController
@RequestMapping("/ollama")
public class ChatController {

    private final ChatClient chatClient;
    private final EmbeddingModel embeddingModel;

    /**
     * Builds a chat client whose default options select the {@code llama3} model.
     *
     * @param builder        builder used to create the chat client
     * @param embeddingModel model used by the embedding endpoint
     */
    public ChatController(ChatClient.Builder builder, EmbeddingModel embeddingModel) {
        OllamaOptions defaultOptions = OllamaOptions.create().withModel("llama3");
        this.chatClient = builder.defaultOptions(defaultOptions).build();
        this.embeddingModel = embeddingModel;
    }

    /**
     * Sends {@code message} as the user prompt and returns the response content.
     *
     * @param message user text; defaults to {@code hello}
     * @return the content of the chat response
     */
    @GetMapping("/chat")
    public String simple(@RequestParam(required = false, defaultValue = "hello") String message) {
        String answer = chatClient.prompt()
                .user(message)
                .call()
                .content();
        return answer;
    }

    /**
     * Requests an embedding for {@code message}.
     *
     * @param message text to embed; defaults to {@code hello}
     * @return the embedding response for the single-element input list
     */
    @GetMapping("/embedding")
    public EmbeddingResponse embedding(@RequestParam(required = false, defaultValue = "hello") String message) {
        List<String> inputs = List.of(message);
        return this.embeddingModel.embedForResponse(inputs);
    }

    /**
     * Same prompt as {@link #simple(String)}, but uses the streaming call.
     *
     * @param message user text; defaults to {@code hello}
     * @return a {@link Flux} of response content
     */
    @GetMapping("/chat/stream")
    public Flux<String> simpleFlux(@RequestParam(required = false, defaultValue = "hello") String message) {
        Flux<String> chunks = chatClient.prompt()
                .user(message)
                .stream()
                .content();
        return chunks;
    }

    /**
     * Asks for an artist's top songs and converts the reply into a list of {@code Song}.
     *
     * @param artist value substituted for the {@code {artist}} placeholder; defaults to {@code Taylor Swift}
     * @return the reply mapped to {@code List<Song>}
     */
    @GetMapping("/chat/parser")
    public List<Song> simpleParser(@RequestParam(required = false, defaultValue = "Taylor Swift") String artist) {
        // Closing delimiter sits at the content's indentation, so the text block
        // value carries no leading whitespace and ends with a single newline.
        String template = """
                Please give me a list of top 10 songs and it's release year for the artist {artist}. If you don't know the answer , just say "I don't know".
                """;
        // The anonymous subclass keeps the generic element type available to the converter.
        ParameterizedTypeReference<List<Song>> songListType = new ParameterizedTypeReference<>() {
        };
        return chatClient.prompt()
                .user(userSpec -> userSpec.text(template).param("artist", artist))
                .call()
                .entity(songListType);
    }
}

然后需要进行如下配置:

1
2
3
4
5
6
7
8
9
10
11
12
# Ollama endpoint plus the default chat and embedding models.
spring:
  application:
    name: xxx
  ai:
    ollama:
      base-url: http://xxx.xxx.xxx.xxx:11434
      chat:
        options:
          model: llama3
      embedding:
        options:
          model: nomic-embed-text

注:此处返回结果的内容与格式和所使用的模型有较大的关系,建议先使用 ollama run llama3 进行测试。

设置不同模型

如果需要为不同的接口使用不同的模型则可以使用如下代码:

1
2
3
4
5
6
7
// Options attached to the Prompt apply to this request only.
ChatResponse response = chatClient.prompt(
        new Prompt(
                "Generate the names of 5 famous pirates.",
                OllamaOptions.create()
                        .withModel("llama2")
                        .withTemperature(0.4)
        ))
        .call()
        // call() yields a response spec, not a ChatResponse; chatResponse() unwraps it.
        .chatResponse();
自定义数据源

如果想要使用自定义数据源则可以采用如下方式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.transformer.splitter.TextSplitter;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;

import java.io.File;
import java.util.List;

/**
 * Configuration that exposes a {@link SimpleVectorStore} bean, loaded from a JSON
 * file when that file exists and otherwise built from the FAQ resource.
 */
@Slf4j
@Configuration
public class RagConfig {

    /** Path of the file the vector store is loaded from or saved to. */
    @Value("./vectorstore.json")
    private String vectorStorePath;

    /** FAQ text that is split and added to the store when no saved file exists. */
    @Value("classpath:/docs/olympic-faq.txt")
    private Resource faq;

    /**
     * Creates the vector store bean.
     *
     * @param embeddingModel embedding model handed to the store
     * @return the store, either loaded from the file or freshly populated and saved
     */
    @Bean
    SimpleVectorStore simpleVectorStore(EmbeddingModel embeddingModel) {
        SimpleVectorStore store = new SimpleVectorStore(embeddingModel);
        File storeFile = new File(vectorStorePath);

        if (storeFile.exists()) {
            log.info("Vector Store File Exists,");
            store.load(storeFile);
            return store;
        }

        log.info("Vector Store File Does Not Exist, load documents");
        store.add(readAndSplitFaq());
        store.save(storeFile);
        return store;
    }

    /** Reads the FAQ resource, tags it with its file name and splits it with a token splitter. */
    private List<Document> readAndSplitFaq() {
        TextReader reader = new TextReader(faq);
        reader.getCustomMetadata().put("filename", "olympic-faq.txt");
        List<Document> rawDocuments = reader.get();
        TextSplitter splitter = new TokenTextSplitter();
        return splitter.apply(rawDocuments);
    }
}

然后编写 RagController :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

/**
 * REST endpoint under {@code /ollama} whose chat client is built with a
 * {@link QuestionAnswerAdvisor} backed by the injected {@link VectorStore}.
 */
@RestController
@RequestMapping("/ollama")
public class RagController {

    private final ChatClient chatClient;

    /**
     * @param builder     builder used to create the chat client
     * @param vectorStore store the question-answer advisor is constructed with
     */
    public RagController(ChatClient.Builder builder, VectorStore vectorStore) {
        QuestionAnswerAdvisor questionAnswerAdvisor =
                new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults());
        this.chatClient = builder.defaultAdvisors(questionAnswerAdvisor).build();
    }

    /**
     * Sends {@code message} as the user prompt through the advised chat client.
     *
     * @param message the question; a default question about Paris 2024 is used when absent
     * @return the content of the chat response
     */
    @GetMapping("/chat/rag")
    public String rag(@RequestParam(value = "message", defaultValue = "How many athletes compete in the Olympic Games Paris 2024") String message) {
        String answer = chatClient.prompt().user(message).call().content();
        return answer;
    }
}

最后需要补充 resources/prompts/rag-prompt-template.st 提示词模板:

1
2
3
4
5
6
7
8
9
You are a helpful assistant, conversing with a user about the subjects contained in a set of documents.
Use the information from the DOCUMENTS section to provide accurate answers. If unsure or if the answer
isn't found in the DOCUMENTS section, simply state that you don't know the answer.

QUESTION:
{input}

DOCUMENTS:
{documents}

和问答资料库 resources/docs/olympic-faq.txt

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

Q: How to buy tickets for the Olympic Games Paris 2024?
A: Tickets for the Olympic Games Paris 2024 are available for spectators around the world only on the official ticketing website. To buy tickets, click here.

The Paris 2024 Hospitality program offers packages that include tickets for sporting events combined with exceptional services in the competition venues (boxes, lounges) or in the heart of the city (accommodation, transport options, gastronomy, tourist activities, etc.).

The Paris 2024 Hospitality program is delivered by the official Paris 2024 Hospitality provider, On Location.

For more information about the Paris 2024 Hospitality & Travel offers, click here.

Q: What is the official mascot of the Olympic Games Paris 2024?
A: The Olympic Games Paris 2024 mascot is Olympic Phryge. The mascot is based on the traditional small Phrygian hats for which they are shaped after.

The name and design were chosen as symbols of freedom and to represent allegorical figures of the French republic.

The Olympic Phryge is decked out in blue, white and red - the colours of France’s famed tricolour flag - with the golden Paris 2024 logo emblazoned across its chest.

Q: When and where are the next Olympic Games?
A: The Olympic Games Paris 2024 will take place in France from 26 July to 11 August.

Q: What sports are in the Olympic Games Paris 2024?
A: 3X3 Basketball
Archery
Artistic Gymnastics
Artistic Swimming
Athletics
Badminton
Basketball
Beach Volleyball
Boxing
Breaking
Canoe Slalom
Canoe Sprint
Cycling BMX Freestyle
Cycling BMX Racing
Cycling Mountain Bike
Cycling Road
Cycling Track
Diving
Equestrian
Fencing
Football
Golf
Handball
Hockey
Judo
Marathon Swimming
Modern Pentathlon
Rhythmic Gymnastics
Rowing
Rugby Sevens
Sailing
Shooting
Skateboarding
Sport Climbing
Surfing
Swimming
Table Tennis
Taekwondo
Tennis
Trampoline
Triathlon
Volleyball
Water Polo
Weightlifting
Wrestling

Q:Where to watch the Olympic Games Paris 2024?
A: In France, the 2024 Olympic Games will be broadcast by Warner Bros. Discovery (formerly Discovery Inc.) via Eurosport, with free-to-air coverage sub-licensed to the country's public broadcaster France Télévisions. For a detailed list of the Paris 2024 Media Rights Holders here.

Q: How many athletes compete in the Olympic Games Paris 2024?
A: Around 10,500 athletes from 206 NOCs will compete.


Q: How often are the modern Olympic Games held?
A: The summer edition of the Olympic Games is normally held every four years.

Q: Where will the 2028 and 2032 Olympic Games be held?
A: Los Angeles, USA, will host the next Olympic Games from 14 to 30 July 2028. Brisbane, Australia, will host the Games in 2032.

Q: What is the difference between the Olympic Summer Games and the Olympic Winter Games?
A: The summer edition of the Olympic Games is a multi-sport event normally held once every four years usually in July or August.

The Olympic Winter Games are also held every four years in the winter months of the host location and the multi-sports competitions are practised on snow and ice.

Both Games are organised by the International Olympic Committee.

Q: Which cities have hosted the Olympic Summer Games?
A: 1896 Athens
1900 Paris
1904 St. Louis
1908 London
1912 Stockholm
1920 Antwerp
1924 Paris
1928 Amsterdam
1932 Los Angeles
1936 Berlin
1948 London
1952 Helsinki
1956 Melbourne
1960 Rome
1964 Tokyo
1968 Mexico City
1972 Munich
1976 Montreal
1980 Moscow
1984 Los Angeles
1988 Seoul
1992 Barcelona
1996 Atlanta
2000 Sydney
2004 Athens
2008 Beijing
2012 London
2016 Rio de Janeiro
2020 Tokyo
2024 Paris

Q: What year did the Olympic Games start?
A: The inaugural Games took place in 1896 in Athens, Greece.

注:如果不配置 Ollama embedding options model 的话,在初次启动时需要拉取 Hugging Face 和 GitHub 当中的内容,启动时间较长且对网络环境要求很高。

对话记录

编写如下代码即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;

/**
 * REST endpoint under {@code /ollama} whose chat client is built with a
 * {@link MessageChatMemoryAdvisor} over an {@link InMemoryChatMemory}.
 */
@RestController
@RequestMapping("/ollama")
public class MemoryController {

    private final ChatClient chatClient;

    /**
     * @param builder builder used to create the chat client
     */
    public MemoryController(ChatClient.Builder builder) {
        InMemoryChatMemory chatMemory = new InMemoryChatMemory();
        MessageChatMemoryAdvisor memoryAdvisor = new MessageChatMemoryAdvisor(chatMemory);
        this.chatClient = builder.defaultAdvisors(memoryAdvisor).build();
    }

    /**
     * Sends {@code message} as the user prompt and passes {@code conversionId} to the
     * advisors under {@code CHAT_MEMORY_CONVERSATION_ID_KEY}.
     *
     * @param message      user text; defaults to {@code Here is chat room 1}
     * @param conversionId conversation identifier; defaults to {@code 1}
     *                     (presumably each id keeps its own history - confirm against the advisor)
     * @return the content of the chat response
     */
    @GetMapping("/chat/memory")
    public String rag(
            @RequestParam(defaultValue = "Here is chat room 1") String message,
            @RequestParam(defaultValue = "1") String conversionId) {
        String reply = chatClient.prompt()
                .user(message)
                .advisors(advisorSpec -> advisorSpec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, conversionId))
                .call()
                .content();
        return reply;
    }
}
对话日志

编写如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/ollama")
public class LogController {

private final ChatClient chatClient;

public LogController(ChatClient.Builder builder) {
this.chatClient = builder.defaultAdvisors((new SimpleLoggerAdvisor()).build();
}

@GetMapping("/chat/log")
public String rag(
@RequestParam(defaultValue = "Hi") String message) {
return chatClient.prompt()
.user(message)
.call()
.content();
}
}

然后修改日志配置即可:

1
2
3
4
5
6
7
8
# DEBUG level for the org.springframework.ai.chat.client.advisor package.
logging:
  level:
    org:
      springframework:
        ai:
          chat:
            client:
              advisor: DEBUG

注:此处要求 Spring AI 的版本为 1.0.0-SNAPSHOT 及以上。

参考资料

官方文档

spring-into-ai

Spring AI 1.0.0 M1 released


Spring AI
https://wangqian0306.github.io/2024/spring-ai/
作者
WangQian
发布于
2024年3月29日
许可协议