Compare commits
4 commits
1ce2658b55
...
bd3267da5b
Author | SHA1 | Date | |
---|---|---|---|
bd3267da5b | |||
aa041df240 | |||
79f448f340 | |||
c5b6e47da0 |
7 changed files with 75 additions and 51 deletions
|
@ -1,5 +1,6 @@
|
||||||
package de.hhhammer.dchat.bot;
|
package de.hhhammer.dchat.bot;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.zaxxer.hikari.HikariConfig;
|
import com.zaxxer.hikari.HikariConfig;
|
||||||
import com.zaxxer.hikari.HikariDataSource;
|
import com.zaxxer.hikari.HikariDataSource;
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||||
|
@ -32,7 +33,7 @@ public class App {
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient());
|
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
|
||||||
|
|
||||||
var config = new HikariConfig();
|
var config = new HikariConfig();
|
||||||
config.setJdbcUrl(postgresUrl);
|
config.setJdbcUrl(postgresUrl);
|
||||||
|
|
|
@ -21,7 +21,7 @@ public class MessageCreateHandler implements MessageCreateListener {
|
||||||
@Override
|
@Override
|
||||||
public void onMessageCreate(MessageCreateEvent event) {
|
public void onMessageCreate(MessageCreateEvent event) {
|
||||||
Thread.ofVirtual().start(() -> {
|
Thread.ofVirtual().start(() -> {
|
||||||
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !(event.getMessage().getType() == MessageType.NORMAL || event.getMessage().getType() == MessageType.REPLY)) {
|
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!this.messageHandler.canHandle(event)) {
|
if (!this.messageHandler.canHandle(event)) {
|
||||||
|
@ -42,4 +42,9 @@ public class MessageCreateHandler implements MessageCreateListener {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean isNormalOrReplyMessageType(MessageCreateEvent event) {
|
||||||
|
MessageType type = event.getMessage().getType();
|
||||||
|
return type == MessageType.NORMAL || type == MessageType.REPLY;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package de.hhhammer.dchat.bot.discord;
|
||||||
|
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||||
|
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
|
||||||
import de.hhhammer.dchat.bot.openai.ResponseException;
|
import de.hhhammer.dchat.bot.openai.ResponseException;
|
||||||
import de.hhhammer.dchat.db.ServerDBService;
|
import de.hhhammer.dchat.db.ServerDBService;
|
||||||
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
||||||
|
@ -10,10 +11,12 @@ import org.javacord.api.entity.message.Message;
|
||||||
import org.javacord.api.entity.message.MessageReference;
|
import org.javacord.api.entity.message.MessageReference;
|
||||||
import org.javacord.api.entity.message.MessageType;
|
import org.javacord.api.entity.message.MessageType;
|
||||||
import org.javacord.api.event.message.MessageCreateEvent;
|
import org.javacord.api.event.message.MessageCreateEvent;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class ServerMessageHandler implements MessageHandler {
|
public class ServerMessageHandler implements MessageHandler {
|
||||||
|
|
||||||
|
@ -31,16 +34,8 @@ public class ServerMessageHandler implements MessageHandler {
|
||||||
String content = extractContent(event);
|
String content = extractContent(event);
|
||||||
var serverId = event.getServer().get().getId();
|
var serverId = event.getServer().get().getId();
|
||||||
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
|
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
|
||||||
var request = event.getMessage().getType() == MessageType.REPLY ?
|
List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
|
||||||
new ChatGPTRequestBuilder()
|
var request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
|
||||||
.replyRequest(event.getMessage()
|
|
||||||
.getMessageReference()
|
|
||||||
.map(MessageReference::getMessage)
|
|
||||||
.flatMap(m -> m)
|
|
||||||
.map(Message::getReadableContent)
|
|
||||||
.stream().toList(),
|
|
||||||
content, systemMessage) :
|
|
||||||
new ChatGPTRequestBuilder().simpleRequest(content, systemMessage);
|
|
||||||
var response = this.chatGPTService.submit(request);
|
var response = this.chatGPTService.submit(request);
|
||||||
if (response.choices().isEmpty()) {
|
if (response.choices().isEmpty()) {
|
||||||
event.getMessage().reply("No response available");
|
event.getMessage().reply("No response available");
|
||||||
|
@ -97,4 +92,16 @@ public class ServerMessageHandler implements MessageHandler {
|
||||||
long ownId = event.getApi().getYourself().getId();
|
long ownId = event.getApi().getYourself().getId();
|
||||||
return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
|
return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
private List<ReplyInteraction> getContextMessages(MessageCreateEvent event) {
|
||||||
|
return event.getMessage()
|
||||||
|
.getMessageReference()
|
||||||
|
.map(MessageReference::getMessage)
|
||||||
|
.flatMap(m -> m)
|
||||||
|
.map(Message::getReadableContent)
|
||||||
|
.map(ReplyInteraction::new)
|
||||||
|
.stream()
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package de.hhhammer.dchat.bot.discord;
|
||||||
|
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||||
|
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
|
||||||
import de.hhhammer.dchat.bot.openai.ResponseException;
|
import de.hhhammer.dchat.bot.openai.ResponseException;
|
||||||
import de.hhhammer.dchat.db.UserDBService;
|
import de.hhhammer.dchat.db.UserDBService;
|
||||||
import de.hhhammer.dchat.db.models.user.UserMessage;
|
import de.hhhammer.dchat.db.models.user.UserMessage;
|
||||||
|
@ -10,6 +11,7 @@ import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class UserMessageHandler implements MessageHandler {
|
public class UserMessageHandler implements MessageHandler {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
|
||||||
|
@ -27,9 +29,9 @@ public class UserMessageHandler implements MessageHandler {
|
||||||
var userId = String.valueOf(event.getMessageAuthor().getId());
|
var userId = String.valueOf(event.getMessageAuthor().getId());
|
||||||
var config = this.userDBService.getConfig(userId).get();
|
var config = this.userDBService.getConfig(userId).get();
|
||||||
var systemMessage = config.systemMessage();
|
var systemMessage = config.systemMessage();
|
||||||
var context = this.userDBService.getLastMessages(userId, config.contextLength())
|
List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
|
||||||
.stream()
|
.stream()
|
||||||
.map(userMessage -> new ChatGPTRequestBuilder.PreviousInteraction(userMessage.question(), userMessage.answer()))
|
.map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
|
||||||
.toList();
|
.toList();
|
||||||
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
|
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
|
||||||
var response = this.chatGPTService.submit(request);
|
var response = this.chatGPTService.submit(request);
|
||||||
|
|
|
@ -1,59 +1,59 @@
|
||||||
package de.hhhammer.dchat.bot.openai;
|
package de.hhhammer.dchat.bot.openai;
|
||||||
|
|
||||||
|
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
|
||||||
|
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
|
||||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class ChatGPTRequestBuilder {
|
public class ChatGPTRequestBuilder {
|
||||||
private final String smallContextModel = "gpt-3.5-turbo";
|
private static final String smallContextModel = "gpt-3.5-turbo";
|
||||||
private final String bigContextModel = "gpt-3.5-turbo-16k";
|
private static final String bigContextModel = "gpt-3.5-turbo-16k";
|
||||||
|
|
||||||
public ChatGPTRequestBuilder() {
|
public ChatGPTRequestBuilder() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatGPTRequest simpleRequest(String content, String systemMessage) {
|
public ChatGPTRequest contextRequest(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
|
||||||
|
List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
|
||||||
|
String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
|
||||||
return new ChatGPTRequest(
|
return new ChatGPTRequest(
|
||||||
smallContextModel,
|
contextModel,
|
||||||
List.of(new ChatGPTRequest.Message("system", systemMessage), new ChatGPTRequest.Message("user", content)),
|
|
||||||
0.7f
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ChatGPTRequest replyRequest(List<String> contextMessages, String message, String systemMessage) {
|
|
||||||
List<ChatGPTRequest.Message> messages = new ArrayList<>();
|
|
||||||
messages.add(new ChatGPTRequest.Message("system", systemMessage));
|
|
||||||
var context = contextMessages.stream()
|
|
||||||
.map(m -> new ChatGPTRequest.Message("assistant", m))
|
|
||||||
.toList();
|
|
||||||
messages.addAll(context);
|
|
||||||
messages.add(new ChatGPTRequest.Message("user", message));
|
|
||||||
return new ChatGPTRequest(
|
|
||||||
smallContextModel,
|
|
||||||
messages,
|
messages,
|
||||||
0.7f
|
0.7f
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatGPTRequest contextRequest(List<PreviousInteraction> contextMessages, String message, String systemMessage) {
|
@NotNull
|
||||||
|
private static List<ChatGPTRequest.Message> getMessages(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
|
||||||
|
var systemMsg = new ChatGPTRequest.Message("system", systemMessage);
|
||||||
|
List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
|
||||||
|
var userMessage = new ChatGPTRequest.Message("user", message);
|
||||||
List<ChatGPTRequest.Message> messages = new ArrayList<>();
|
List<ChatGPTRequest.Message> messages = new ArrayList<>();
|
||||||
messages.add(new ChatGPTRequest.Message("system", systemMessage));
|
messages.add(systemMsg);
|
||||||
var context = contextMessages.stream()
|
messages.addAll(contextMsgs);
|
||||||
.map(m -> List.of(
|
messages.add(userMessage);
|
||||||
new ChatGPTRequest.Message("user", m.question),
|
return messages;
|
||||||
new ChatGPTRequest.Message("assistant", m.answer)
|
}
|
||||||
))
|
|
||||||
|
@NotNull
|
||||||
|
private static List<ChatGPTRequest.Message> getContextMessages(List<? extends MessageContext> contextMessages) {
|
||||||
|
return contextMessages.stream()
|
||||||
|
.map(ChatGPTRequestBuilder::mapContextMessages)
|
||||||
.flatMap(List::stream)
|
.flatMap(List::stream)
|
||||||
.toList();
|
.toList();
|
||||||
messages.addAll(context);
|
|
||||||
messages.add(new ChatGPTRequest.Message("user", message));
|
|
||||||
return new ChatGPTRequest(
|
|
||||||
bigContextModel,
|
|
||||||
messages,
|
|
||||||
0.7f
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public record PreviousInteraction(String question, String answer) {
|
@NotNull
|
||||||
|
private static List<ChatGPTRequest.Message> mapContextMessages(MessageContext contextMessage) {
|
||||||
|
return switch (contextMessage) {
|
||||||
|
case PreviousInteraction previousInteractions -> List.of(
|
||||||
|
new ChatGPTRequest.Message("user", previousInteractions.question()),
|
||||||
|
new ChatGPTRequest.Message("assistant", previousInteractions.answer())
|
||||||
|
);
|
||||||
|
case ReplyInteraction replyInteractions ->
|
||||||
|
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,14 +13,14 @@ import java.time.Duration;
|
||||||
|
|
||||||
public class ChatGPTService {
|
public class ChatGPTService {
|
||||||
private static final String url = "https://api.openai.com/v1/chat/completions";
|
private static final String url = "https://api.openai.com/v1/chat/completions";
|
||||||
private final ObjectMapper mapper = new ObjectMapper();
|
|
||||||
private final String apiKey;
|
private final String apiKey;
|
||||||
private final HttpClient httpClient;
|
private final HttpClient httpClient;
|
||||||
|
private final ObjectMapper mapper;
|
||||||
|
|
||||||
|
public ChatGPTService(String apiKey, HttpClient httpClient, ObjectMapper mapper) {
|
||||||
public ChatGPTService(String apiKey, HttpClient httpClient) {
|
|
||||||
this.apiKey = apiKey;
|
this.apiKey = apiKey;
|
||||||
this.httpClient = httpClient;
|
this.httpClient = httpClient;
|
||||||
|
this.mapper = mapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
|
public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
package de.hhhammer.dchat.bot.openai;
|
||||||
|
|
||||||
|
public sealed interface MessageContext {
|
||||||
|
record ReplyInteraction(String answer) implements MessageContext {
|
||||||
|
}
|
||||||
|
|
||||||
|
record PreviousInteraction(String question, String answer) implements MessageContext {
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue