Compare commits

...

4 commits

Author SHA1 Message Date
bd3267da5b bot: Use constructor injection for the ObjectMapper
Some checks failed
ci/woodpecker/push/java/1 Pipeline was successful
ci/woodpecker/push/java/2 Pipeline was successful
ci/woodpecker/push/java/3 Pipeline was successful
ci/woodpecker/push/java/4 Pipeline was successful
ci/woodpecker/push/nodejs Pipeline was successful
ci/woodpecker/push/oci-image-build/1 Pipeline failed
ci/woodpecker/push/oci-image-build/2 Pipeline failed
ci/woodpecker/push/oci-image-build/3 Pipeline failed
ci/woodpecker/push/oci-image-build/4 Pipeline failed
2023-12-10 21:03:30 +01:00
aa041df240 bot: Refactor to use pattern matching for building the message context 2023-12-10 20:58:25 +01:00
79f448f340 bot: Remove differentiating between direct and reply request 2023-12-10 20:14:03 +01:00
c5b6e47da0 bot: Extract message type check into own method
Makes it easier to read.
2023-12-10 19:55:51 +01:00
7 changed files with 75 additions and 51 deletions

View file

@ -1,5 +1,6 @@
package de.hhhammer.dchat.bot; package de.hhhammer.dchat.bot;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
@ -32,7 +33,7 @@ public class App {
System.exit(1); System.exit(1);
} }
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient()); var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
var config = new HikariConfig(); var config = new HikariConfig();
config.setJdbcUrl(postgresUrl); config.setJdbcUrl(postgresUrl);

View file

@ -21,7 +21,7 @@ public class MessageCreateHandler implements MessageCreateListener {
@Override @Override
public void onMessageCreate(MessageCreateEvent event) { public void onMessageCreate(MessageCreateEvent event) {
Thread.ofVirtual().start(() -> { Thread.ofVirtual().start(() -> {
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !(event.getMessage().getType() == MessageType.NORMAL || event.getMessage().getType() == MessageType.REPLY)) { if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
return; return;
} }
if (!this.messageHandler.canHandle(event)) { if (!this.messageHandler.canHandle(event)) {
@ -42,4 +42,9 @@ public class MessageCreateHandler implements MessageCreateListener {
} }
}); });
} }
private boolean isNormalOrReplyMessageType(MessageCreateEvent event) {
MessageType type = event.getMessage().getType();
return type == MessageType.NORMAL || type == MessageType.REPLY;
}
} }

View file

@ -2,6 +2,7 @@ package de.hhhammer.dchat.bot.discord;
import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder; import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException; import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.db.ServerDBService; import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.models.server.ServerMessage; import de.hhhammer.dchat.db.models.server.ServerMessage;
@ -10,10 +11,12 @@ import org.javacord.api.entity.message.Message;
import org.javacord.api.entity.message.MessageReference; import org.javacord.api.entity.message.MessageReference;
import org.javacord.api.entity.message.MessageType; import org.javacord.api.entity.message.MessageType;
import org.javacord.api.event.message.MessageCreateEvent; import org.javacord.api.event.message.MessageCreateEvent;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.List;
public class ServerMessageHandler implements MessageHandler { public class ServerMessageHandler implements MessageHandler {
@ -31,16 +34,8 @@ public class ServerMessageHandler implements MessageHandler {
String content = extractContent(event); String content = extractContent(event);
var serverId = event.getServer().get().getId(); var serverId = event.getServer().get().getId();
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage(); var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
var request = event.getMessage().getType() == MessageType.REPLY ? List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
new ChatGPTRequestBuilder() var request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
.replyRequest(event.getMessage()
.getMessageReference()
.map(MessageReference::getMessage)
.flatMap(m -> m)
.map(Message::getReadableContent)
.stream().toList(),
content, systemMessage) :
new ChatGPTRequestBuilder().simpleRequest(content, systemMessage);
var response = this.chatGPTService.submit(request); var response = this.chatGPTService.submit(request);
if (response.choices().isEmpty()) { if (response.choices().isEmpty()) {
event.getMessage().reply("No response available"); event.getMessage().reply("No response available");
@ -97,4 +92,16 @@ public class ServerMessageHandler implements MessageHandler {
long ownId = event.getApi().getYourself().getId(); long ownId = event.getApi().getYourself().getId();
return event.getMessageContent().replaceFirst("<" + ownId + "> ", ""); return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
} }
@NotNull
private List<ReplyInteraction> getContextMessages(MessageCreateEvent event) {
return event.getMessage()
.getMessageReference()
.map(MessageReference::getMessage)
.flatMap(m -> m)
.map(Message::getReadableContent)
.map(ReplyInteraction::new)
.stream()
.toList();
}
} }

View file

@ -2,6 +2,7 @@ package de.hhhammer.dchat.bot.discord;
import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder; import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException; import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.db.UserDBService; import de.hhhammer.dchat.db.UserDBService;
import de.hhhammer.dchat.db.models.user.UserMessage; import de.hhhammer.dchat.db.models.user.UserMessage;
@ -10,6 +11,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.List;
public class UserMessageHandler implements MessageHandler { public class UserMessageHandler implements MessageHandler {
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class); private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
@ -27,9 +29,9 @@ public class UserMessageHandler implements MessageHandler {
var userId = String.valueOf(event.getMessageAuthor().getId()); var userId = String.valueOf(event.getMessageAuthor().getId());
var config = this.userDBService.getConfig(userId).get(); var config = this.userDBService.getConfig(userId).get();
var systemMessage = config.systemMessage(); var systemMessage = config.systemMessage();
var context = this.userDBService.getLastMessages(userId, config.contextLength()) List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
.stream() .stream()
.map(userMessage -> new ChatGPTRequestBuilder.PreviousInteraction(userMessage.question(), userMessage.answer())) .map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
.toList(); .toList();
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage); var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
var response = this.chatGPTService.submit(request); var response = this.chatGPTService.submit(request);

View file

@ -1,59 +1,59 @@
package de.hhhammer.dchat.bot.openai; package de.hhhammer.dchat.bot.openai;
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest; import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
public class ChatGPTRequestBuilder { public class ChatGPTRequestBuilder {
private final String smallContextModel = "gpt-3.5-turbo"; private static final String smallContextModel = "gpt-3.5-turbo";
private final String bigContextModel = "gpt-3.5-turbo-16k"; private static final String bigContextModel = "gpt-3.5-turbo-16k";
public ChatGPTRequestBuilder() { public ChatGPTRequestBuilder() {
} }
public ChatGPTRequest simpleRequest(String content, String systemMessage) { public ChatGPTRequest contextRequest(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
return new ChatGPTRequest( return new ChatGPTRequest(
smallContextModel, contextModel,
List.of(new ChatGPTRequest.Message("system", systemMessage), new ChatGPTRequest.Message("user", content)),
0.7f
);
}
public ChatGPTRequest replyRequest(List<String> contextMessages, String message, String systemMessage) {
List<ChatGPTRequest.Message> messages = new ArrayList<>();
messages.add(new ChatGPTRequest.Message("system", systemMessage));
var context = contextMessages.stream()
.map(m -> new ChatGPTRequest.Message("assistant", m))
.toList();
messages.addAll(context);
messages.add(new ChatGPTRequest.Message("user", message));
return new ChatGPTRequest(
smallContextModel,
messages, messages,
0.7f 0.7f
); );
} }
public ChatGPTRequest contextRequest(List<PreviousInteraction> contextMessages, String message, String systemMessage) { @NotNull
private static List<ChatGPTRequest.Message> getMessages(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
var systemMsg = new ChatGPTRequest.Message("system", systemMessage);
List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
var userMessage = new ChatGPTRequest.Message("user", message);
List<ChatGPTRequest.Message> messages = new ArrayList<>(); List<ChatGPTRequest.Message> messages = new ArrayList<>();
messages.add(new ChatGPTRequest.Message("system", systemMessage)); messages.add(systemMsg);
var context = contextMessages.stream() messages.addAll(contextMsgs);
.map(m -> List.of( messages.add(userMessage);
new ChatGPTRequest.Message("user", m.question), return messages;
new ChatGPTRequest.Message("assistant", m.answer) }
))
@NotNull
private static List<ChatGPTRequest.Message> getContextMessages(List<? extends MessageContext> contextMessages) {
return contextMessages.stream()
.map(ChatGPTRequestBuilder::mapContextMessages)
.flatMap(List::stream) .flatMap(List::stream)
.toList(); .toList();
messages.addAll(context);
messages.add(new ChatGPTRequest.Message("user", message));
return new ChatGPTRequest(
bigContextModel,
messages,
0.7f
);
} }
public record PreviousInteraction(String question, String answer) { @NotNull
private static List<ChatGPTRequest.Message> mapContextMessages(MessageContext contextMessage) {
return switch (contextMessage) {
case PreviousInteraction previousInteractions -> List.of(
new ChatGPTRequest.Message("user", previousInteractions.question()),
new ChatGPTRequest.Message("assistant", previousInteractions.answer())
);
case ReplyInteraction replyInteractions ->
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
};
} }
} }

View file

@ -13,14 +13,14 @@ import java.time.Duration;
public class ChatGPTService { public class ChatGPTService {
private static final String url = "https://api.openai.com/v1/chat/completions"; private static final String url = "https://api.openai.com/v1/chat/completions";
private final ObjectMapper mapper = new ObjectMapper();
private final String apiKey; private final String apiKey;
private final HttpClient httpClient; private final HttpClient httpClient;
private final ObjectMapper mapper;
public ChatGPTService(String apiKey, HttpClient httpClient, ObjectMapper mapper) {
public ChatGPTService(String apiKey, HttpClient httpClient) {
this.apiKey = apiKey; this.apiKey = apiKey;
this.httpClient = httpClient; this.httpClient = httpClient;
this.mapper = mapper;
} }
public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException { public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {

View file

@ -0,0 +1,9 @@
package de.hhhammer.dchat.bot.openai;
public sealed interface MessageContext {
record ReplyInteraction(String answer) implements MessageContext {
}
record PreviousInteraction(String question, String answer) implements MessageContext {
}
}