Compare commits
5 commits
5c67a47806
...
168a5d6c81
Author | SHA1 | Date | |
---|---|---|---|
168a5d6c81 | |||
b77836effb | |||
e58980cea3 | |||
ca95bb45cb | |||
ab48afd5ed |
24 changed files with 642 additions and 562 deletions
|
@ -4,47 +4,47 @@ import com.fasterxml.jackson.databind.ObjectMapper;
|
|||
import com.zaxxer.hikari.HikariConfig;
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||
import de.hhhammer.dchat.db.ServerDBService;
|
||||
import de.hhhammer.dchat.db.UserDBService;
|
||||
import de.hhhammer.dchat.db.PostgresServerDBService;
|
||||
import de.hhhammer.dchat.db.PostgresUserDBService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.net.http.HttpClient;
|
||||
|
||||
public class App {
|
||||
public final class App {
|
||||
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
||||
|
||||
public static void main(String[] args) {
|
||||
String discordApiKey = System.getenv("DISCORD_API_KEY");
|
||||
public static void main(final String[] args) {
|
||||
final String discordApiKey = System.getenv("DISCORD_API_KEY");
|
||||
if (discordApiKey == null) {
|
||||
logger.error("Missing environment variables: DISCORD_API_KEY");
|
||||
System.exit(1);
|
||||
}
|
||||
String openaiApiKey = System.getenv("OPENAI_API_KEY");
|
||||
final String openaiApiKey = System.getenv("OPENAI_API_KEY");
|
||||
if (openaiApiKey == null) {
|
||||
logger.error("Missing environment variables: OPENAI_API_KEY");
|
||||
System.exit(1);
|
||||
}
|
||||
String postgresUser = System.getenv("POSTGRES_USER");
|
||||
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||
String postgresUrl = System.getenv("POSTGRES_URL");
|
||||
final String postgresUser = System.getenv("POSTGRES_USER");
|
||||
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||
final String postgresUrl = System.getenv("POSTGRES_URL");
|
||||
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
||||
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
|
||||
final var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
|
||||
|
||||
var config = new HikariConfig();
|
||||
final var config = new HikariConfig();
|
||||
config.setJdbcUrl(postgresUrl);
|
||||
config.setUsername(postgresUser);
|
||||
config.setPassword(postgresPassword);
|
||||
|
||||
try (var ds = new HikariDataSource(config)) {
|
||||
var serverDBService = new ServerDBService(ds);
|
||||
var userDBService = new UserDBService(ds);
|
||||
final var serverDBService = new PostgresServerDBService(ds);
|
||||
final var userDBService = new PostgresUserDBService(ds);
|
||||
|
||||
var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
|
||||
final var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
|
||||
discordBot.run();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import de.hhhammer.dchat.db.UserDBService;
|
|||
import org.javacord.api.DiscordApi;
|
||||
import org.javacord.api.DiscordApiBuilder;
|
||||
import org.javacord.api.interaction.SlashCommand;
|
||||
import org.javacord.api.interaction.SlashCommandInteraction;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -16,7 +17,7 @@ import java.util.concurrent.CompletableFuture;
|
|||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
public class DiscordBot implements Runnable {
|
||||
public final class DiscordBot implements Runnable {
|
||||
private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class);
|
||||
|
||||
private final ServerDBService serverDBService;
|
||||
|
@ -24,7 +25,7 @@ public class DiscordBot implements Runnable {
|
|||
private final ChatGPTService chatGPTService;
|
||||
private final String discordApiKey;
|
||||
|
||||
public DiscordBot(ServerDBService serverDBService, UserDBService userDBService, ChatGPTService chatGPTService, String discordApiKey) {
|
||||
public DiscordBot(final ServerDBService serverDBService, final UserDBService userDBService, final ChatGPTService chatGPTService, final String discordApiKey) {
|
||||
this.serverDBService = serverDBService;
|
||||
this.userDBService = userDBService;
|
||||
this.chatGPTService = chatGPTService;
|
||||
|
@ -34,29 +35,29 @@ public class DiscordBot implements Runnable {
|
|||
@Override
|
||||
public void run() {
|
||||
logger.info("Starting Discord application");
|
||||
DiscordApi discordApi = new DiscordApiBuilder()
|
||||
final DiscordApi discordApi = new DiscordApiBuilder()
|
||||
.setToken(discordApiKey)
|
||||
.login()
|
||||
.join();
|
||||
discordApi.setMessageCacheSize(10, 60*60);
|
||||
var future = new CompletableFuture<Void>();
|
||||
discordApi.setMessageCacheSize(10, 60 * 60);
|
||||
final var future = new CompletableFuture<Void>();
|
||||
Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> {
|
||||
logger.info("Shutting down Discord application");
|
||||
discordApi.disconnect().thenAccept(future::complete);
|
||||
}));
|
||||
var token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
|
||||
final SlashCommand token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
|
||||
.createGlobal(discordApi)
|
||||
.join();
|
||||
|
||||
discordApi.addSlashCommandCreateListener(event -> {
|
||||
logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName());
|
||||
var command = event.getSlashCommandInteraction();
|
||||
final SlashCommandInteraction command = event.getSlashCommandInteraction();
|
||||
if (token.getFullCommandNames().contains(command.getFullCommandName())) {
|
||||
event.getInteraction()
|
||||
.respondLater()
|
||||
.orTimeout(30, TimeUnit.SECONDS)
|
||||
.thenAccept((interactionOriginalResponseUpdater) -> {
|
||||
var tokens = event.getInteraction().getServer().isPresent() ?
|
||||
final long tokens = event.getInteraction().getServer().isPresent() ?
|
||||
this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) :
|
||||
this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId()));
|
||||
interactionOriginalResponseUpdater.setContent("" + tokens).update();
|
||||
|
|
|
@ -9,17 +9,17 @@ import org.slf4j.LoggerFactory;
|
|||
|
||||
import java.io.IOException;
|
||||
|
||||
public class MessageCreateHandler implements MessageCreateListener {
|
||||
public final class MessageCreateHandler implements MessageCreateListener {
|
||||
private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class);
|
||||
|
||||
private final MessageHandler messageHandler;
|
||||
|
||||
public MessageCreateHandler(MessageHandler messageHandler) {
|
||||
public MessageCreateHandler(final MessageHandler messageHandler) {
|
||||
this.messageHandler = messageHandler;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessageCreate(MessageCreateEvent event) {
|
||||
public void onMessageCreate(final MessageCreateEvent event) {
|
||||
Thread.ofVirtual().start(() -> {
|
||||
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
|
||||
return;
|
||||
|
@ -36,15 +36,15 @@ public class MessageCreateHandler implements MessageCreateListener {
|
|||
}
|
||||
try {
|
||||
this.messageHandler.handle(event);
|
||||
} catch (ResponseException | IOException | InterruptedException e) {
|
||||
} catch (final ResponseException | IOException | InterruptedException e) {
|
||||
logger.error("Reading a message from the listener", e);
|
||||
event.getMessage().reply("Sorry but something went wrong :(");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private boolean isNormalOrReplyMessageType(MessageCreateEvent event) {
|
||||
MessageType type = event.getMessage().getType();
|
||||
private boolean isNormalOrReplyMessageType(final MessageCreateEvent event) {
|
||||
final MessageType type = event.getMessage().getType();
|
||||
return type == MessageType.NORMAL || type == MessageType.REPLY;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,10 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
|||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
|
||||
import de.hhhammer.dchat.bot.openai.ResponseException;
|
||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
|
||||
import de.hhhammer.dchat.db.ServerDBService;
|
||||
import de.hhhammer.dchat.db.models.server.ServerConfig;
|
||||
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
||||
import org.javacord.api.entity.DiscordEntity;
|
||||
import org.javacord.api.entity.message.Message;
|
||||
|
@ -17,43 +20,44 @@ import org.slf4j.LoggerFactory;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class ServerMessageHandler implements MessageHandler {
|
||||
public final class ServerMessageHandler implements MessageHandler {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class);
|
||||
private final ServerDBService serverDBService;
|
||||
private final ChatGPTService chatGPTService;
|
||||
|
||||
public ServerMessageHandler(ServerDBService serverDBService, ChatGPTService chatGPTService) {
|
||||
public ServerMessageHandler(final ServerDBService serverDBService, final ChatGPTService chatGPTService) {
|
||||
this.serverDBService = serverDBService;
|
||||
this.chatGPTService = chatGPTService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
||||
String content = extractContent(event);
|
||||
var serverId = event.getServer().get().getId();
|
||||
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
|
||||
List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
|
||||
var request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
|
||||
var response = this.chatGPTService.submit(request);
|
||||
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
||||
final String content = extractContent(event);
|
||||
final long serverId = event.getServer().get().getId();
|
||||
final String systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
|
||||
final List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
|
||||
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
|
||||
final ChatGPTResponse response = this.chatGPTService.submit(request);
|
||||
if (response.choices().isEmpty()) {
|
||||
event.getMessage().reply("No response available");
|
||||
return;
|
||||
}
|
||||
var answer = response.choices().get(0).message().content();
|
||||
final String answer = response.choices().get(0).message().content();
|
||||
logServerMessage(event, response.usage().totalTokens());
|
||||
event.getMessage().reply(answer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAllowed(MessageCreateEvent event) {
|
||||
public boolean isAllowed(final MessageCreateEvent event) {
|
||||
if (event.getServer().isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
var serverId = event.getServer().get().getId();
|
||||
var config = this.serverDBService.getConfig(String.valueOf(serverId));
|
||||
final long serverId = event.getServer().get().getId();
|
||||
final Optional<ServerConfig> config = this.serverDBService.getConfig(String.valueOf(serverId));
|
||||
if (config.isEmpty()) {
|
||||
logger.debug("Not allowed with id: " + serverId);
|
||||
return false;
|
||||
|
@ -62,39 +66,39 @@ public class ServerMessageHandler implements MessageHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public boolean exceedsRate(MessageCreateEvent event) {
|
||||
var serverId = String.valueOf(event.getServer().get().getId());
|
||||
var config = this.serverDBService.getConfig(serverId);
|
||||
public boolean exceedsRate(final MessageCreateEvent event) {
|
||||
final String serverId = String.valueOf(event.getServer().get().getId());
|
||||
final Optional<ServerConfig> config = this.serverDBService.getConfig(serverId);
|
||||
if (config.isEmpty()) {
|
||||
logger.error("Missing configuration for server with id: " + serverId);
|
||||
return true;
|
||||
}
|
||||
var rateLimit = config.get().rateLimit();
|
||||
var countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
|
||||
final int rateLimit = config.get().rateLimit();
|
||||
final int countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
|
||||
|
||||
return countMessagesInLastMinute >= rateLimit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canHandle(MessageCreateEvent event) {
|
||||
public boolean canHandle(final MessageCreateEvent event) {
|
||||
return event.isServerMessage();
|
||||
}
|
||||
|
||||
private void logServerMessage(MessageCreateEvent event, int tokens) {
|
||||
var serverId = event.getServer().map(DiscordEntity::getId).get();
|
||||
var userId = event.getMessageAuthor().getId();
|
||||
private void logServerMessage(final MessageCreateEvent event, final int tokens) {
|
||||
final long serverId = event.getServer().map(DiscordEntity::getId).get();
|
||||
final long userId = event.getMessageAuthor().getId();
|
||||
|
||||
var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
|
||||
final var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
|
||||
this.serverDBService.addMessage(serverMessage);
|
||||
}
|
||||
|
||||
private String extractContent(MessageCreateEvent event) {
|
||||
long ownId = event.getApi().getYourself().getId();
|
||||
private String extractContent(final MessageCreateEvent event) {
|
||||
final long ownId = event.getApi().getYourself().getId();
|
||||
return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private List<ReplyInteraction> getContextMessages(MessageCreateEvent event) {
|
||||
private List<ReplyInteraction> getContextMessages(final MessageCreateEvent event) {
|
||||
return event.getMessage()
|
||||
.getMessageReference()
|
||||
.map(MessageReference::getMessage)
|
||||
|
|
|
@ -4,7 +4,10 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
|||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
|
||||
import de.hhhammer.dchat.bot.openai.ResponseException;
|
||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
|
||||
import de.hhhammer.dchat.db.UserDBService;
|
||||
import de.hhhammer.dchat.db.models.user.UserConfig;
|
||||
import de.hhhammer.dchat.db.models.user.UserMessage;
|
||||
import org.javacord.api.event.message.MessageCreateEvent;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -12,46 +15,47 @@ import org.slf4j.LoggerFactory;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class UserMessageHandler implements MessageHandler {
|
||||
public final class UserMessageHandler implements MessageHandler {
|
||||
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
|
||||
private final UserDBService userDBService;
|
||||
private final ChatGPTService chatGPTService;
|
||||
|
||||
public UserMessageHandler(UserDBService userDBService, ChatGPTService chatGPTService) {
|
||||
public UserMessageHandler(final UserDBService userDBService, final ChatGPTService chatGPTService) {
|
||||
this.userDBService = userDBService;
|
||||
this.chatGPTService = chatGPTService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
||||
String content = event.getReadableMessageContent();
|
||||
var userId = String.valueOf(event.getMessageAuthor().getId());
|
||||
var config = this.userDBService.getConfig(userId).get();
|
||||
var systemMessage = config.systemMessage();
|
||||
List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
|
||||
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
||||
final String content = event.getReadableMessageContent();
|
||||
final String userId = String.valueOf(event.getMessageAuthor().getId());
|
||||
final UserConfig config = this.userDBService.getConfig(userId).get();
|
||||
final String systemMessage = config.systemMessage();
|
||||
final List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
|
||||
.stream()
|
||||
.map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
|
||||
.toList();
|
||||
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
|
||||
var response = this.chatGPTService.submit(request);
|
||||
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
|
||||
final ChatGPTResponse response = this.chatGPTService.submit(request);
|
||||
if (response.choices().isEmpty()) {
|
||||
event.getMessage().reply("No response available");
|
||||
return;
|
||||
}
|
||||
var answer = response.choices().get(0).message().content();
|
||||
final String answer = response.choices().get(0).message().content();
|
||||
logUserMessage(event, content, answer, response.usage().totalTokens());
|
||||
event.getChannel().sendMessage(answer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAllowed(MessageCreateEvent event) {
|
||||
public boolean isAllowed(final MessageCreateEvent event) {
|
||||
if (event.getServer().isPresent()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
var userId = event.getMessageAuthor().getId();
|
||||
var config = this.userDBService.getConfig(String.valueOf(userId));
|
||||
final long userId = event.getMessageAuthor().getId();
|
||||
final Optional<UserConfig> config = this.userDBService.getConfig(String.valueOf(userId));
|
||||
if (config.isEmpty()) {
|
||||
logger.debug("Not allowed with id: " + userId);
|
||||
return false;
|
||||
|
@ -60,28 +64,28 @@ public class UserMessageHandler implements MessageHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public boolean exceedsRate(MessageCreateEvent event) {
|
||||
var userId = String.valueOf(event.getMessageAuthor().getId());
|
||||
var config = this.userDBService.getConfig(userId);
|
||||
public boolean exceedsRate(final MessageCreateEvent event) {
|
||||
final String userId = String.valueOf(event.getMessageAuthor().getId());
|
||||
final Optional<UserConfig> config = this.userDBService.getConfig(userId);
|
||||
if (config.isEmpty()) {
|
||||
logger.error("Missing configuration for userId with id: " + userId);
|
||||
return true;
|
||||
}
|
||||
var rateLimit = config.get().rateLimit();
|
||||
var countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
|
||||
final int rateLimit = config.get().rateLimit();
|
||||
final int countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
|
||||
|
||||
return countMessagesInLastMinute >= rateLimit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canHandle(MessageCreateEvent event) {
|
||||
public boolean canHandle(final MessageCreateEvent event) {
|
||||
return event.isPrivateMessage();
|
||||
}
|
||||
|
||||
private void logUserMessage(MessageCreateEvent event, String question, String answer, int tokens) {
|
||||
var userId = event.getMessageAuthor().getId();
|
||||
private void logUserMessage(final MessageCreateEvent event, final String question, final String answer, final int tokens) {
|
||||
final long userId = event.getMessageAuthor().getId();
|
||||
|
||||
var userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
|
||||
final UserMessage.NewUserMessage userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
|
||||
this.userDBService.addMessage(userMessage);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,29 +8,19 @@ import org.jetbrains.annotations.NotNull;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ChatGPTRequestBuilder {
|
||||
public final class ChatGPTRequestBuilder {
|
||||
private static final String smallContextModel = "gpt-3.5-turbo";
|
||||
private static final String bigContextModel = "gpt-3.5-turbo-16k";
|
||||
|
||||
public ChatGPTRequestBuilder() {
|
||||
}
|
||||
|
||||
public ChatGPTRequest contextRequest(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
|
||||
List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
|
||||
String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
|
||||
return new ChatGPTRequest(
|
||||
contextModel,
|
||||
messages,
|
||||
0.7f
|
||||
);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static List<ChatGPTRequest.Message> getMessages(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
|
||||
var systemMsg = new ChatGPTRequest.Message("system", systemMessage);
|
||||
List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
|
||||
var userMessage = new ChatGPTRequest.Message("user", message);
|
||||
List<ChatGPTRequest.Message> messages = new ArrayList<>();
|
||||
private static List<ChatGPTRequest.Message> getMessages(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
|
||||
final ChatGPTRequest.Message systemMsg = new ChatGPTRequest.Message("system", systemMessage);
|
||||
final List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
|
||||
final ChatGPTRequest.Message userMessage = new ChatGPTRequest.Message("user", message);
|
||||
final List<ChatGPTRequest.Message> messages = new ArrayList<>();
|
||||
messages.add(systemMsg);
|
||||
messages.addAll(contextMsgs);
|
||||
messages.add(userMessage);
|
||||
|
@ -38,7 +28,7 @@ public class ChatGPTRequestBuilder {
|
|||
}
|
||||
|
||||
@NotNull
|
||||
private static List<ChatGPTRequest.Message> getContextMessages(List<? extends MessageContext> contextMessages) {
|
||||
private static List<ChatGPTRequest.Message> getContextMessages(final List<? extends MessageContext> contextMessages) {
|
||||
return contextMessages.stream()
|
||||
.map(ChatGPTRequestBuilder::mapContextMessages)
|
||||
.flatMap(List::stream)
|
||||
|
@ -46,7 +36,7 @@ public class ChatGPTRequestBuilder {
|
|||
}
|
||||
|
||||
@NotNull
|
||||
private static List<ChatGPTRequest.Message> mapContextMessages(MessageContext contextMessage) {
|
||||
private static List<ChatGPTRequest.Message> mapContextMessages(final MessageContext contextMessage) {
|
||||
return switch (contextMessage) {
|
||||
case PreviousInteraction previousInteractions -> List.of(
|
||||
new ChatGPTRequest.Message("user", previousInteractions.question()),
|
||||
|
@ -56,4 +46,14 @@ public class ChatGPTRequestBuilder {
|
|||
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
|
||||
};
|
||||
}
|
||||
|
||||
public ChatGPTRequest contextRequest(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
|
||||
final List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
|
||||
final String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
|
||||
return new ChatGPTRequest(
|
||||
contextModel,
|
||||
messages,
|
||||
0.7f
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,34 +5,35 @@ import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
|||
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.time.Duration;
|
||||
|
||||
public class ChatGPTService {
|
||||
public final class ChatGPTService {
|
||||
private static final String url = "https://api.openai.com/v1/chat/completions";
|
||||
private final String apiKey;
|
||||
private final HttpClient httpClient;
|
||||
private final ObjectMapper mapper;
|
||||
|
||||
public ChatGPTService(String apiKey, HttpClient httpClient, ObjectMapper mapper) {
|
||||
public ChatGPTService(final String apiKey, final HttpClient httpClient, final ObjectMapper mapper) {
|
||||
this.apiKey = apiKey;
|
||||
this.httpClient = httpClient;
|
||||
this.mapper = mapper;
|
||||
}
|
||||
|
||||
public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
|
||||
var data = mapper.writeValueAsBytes(chatGPTRequest);
|
||||
var request = HttpRequest.newBuilder(URI.create(url))
|
||||
public ChatGPTResponse submit(final ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
|
||||
final byte[] data = mapper.writeValueAsBytes(chatGPTRequest);
|
||||
final HttpRequest request = HttpRequest.newBuilder(URI.create(url))
|
||||
.POST(HttpRequest.BodyPublishers.ofByteArray(data))
|
||||
.setHeader("Content-Type", "application/json")
|
||||
.setHeader("Authorization", "Bearer " + this.apiKey)
|
||||
.timeout(Duration.ofMinutes(5))
|
||||
.build();
|
||||
|
||||
var responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
final HttpResponse<InputStream> responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
if (responseStream.statusCode() != 200) {
|
||||
throw new ResponseException("Response status code was not 200: " + responseStream.statusCode());
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package de.hhhammer.dchat.bot.openai;
|
||||
|
||||
public class ResponseException extends Exception {
|
||||
public ResponseException(String message) {
|
||||
public final class ResponseException extends Exception {
|
||||
public ResponseException(final String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package de.hhhammer.dchat.db;
|
||||
|
||||
public class DBException extends Exception {
|
||||
public DBException(String message, Throwable cause) {
|
||||
public final class DBException extends Exception {
|
||||
public DBException(final String message, final Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
package de.hhhammer.dchat.db;
|
||||
|
||||
import de.hhhammer.dchat.db.models.server.ServerConfig;
|
||||
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.*;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
public final class PostgresServerDBService implements ServerDBService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(PostgresServerDBService.class);
|
||||
private final DataSource dataSource;
|
||||
|
||||
public PostgresServerDBService(final DataSource dataSource) {
|
||||
this.dataSource = dataSource;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ServerConfig> getConfig(final String serverId) {
|
||||
final var getServerConfig = """
|
||||
SELECT * FROM server_configs WHERE server_id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, serverId);
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Getting configuration for server with id: " + serverId, e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ServerConfig> getAllConfigs() throws DBException {
|
||||
final var getAllowedServerSql = """
|
||||
SELECT * FROM server_configs
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
|
||||
) {
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Loading all configs", e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over configs", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ServerConfig> getConfigBy(final long id) throws DBException {
|
||||
final var getServerConfig = """
|
||||
SELECT * FROM server_configs WHERE id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Getting configuration with id: " + id, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addConfig(final ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
||||
final var getServerConfig = """
|
||||
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newServerConfig.serverId());
|
||||
pstmt.setString(2, newServerConfig.systemMessage());
|
||||
pstmt.setInt(3, newServerConfig.rateLimit());
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config added for server with id: " + newServerConfig.serverId());
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateConfig(final long id, final ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
||||
final var getServerConfig = """
|
||||
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newServerConfig.systemMessage());
|
||||
pstmt.setInt(2, newServerConfig.rateLimit());
|
||||
pstmt.setString(3, newServerConfig.serverId());
|
||||
pstmt.setLong(4, id);
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config update for server with id: " + newServerConfig.serverId());
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteConfig(final long id) throws DBException {
|
||||
final var getServerConfig = """
|
||||
DELETE FROM server_configs WHERE id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config deleted for server with id: " + id);
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Deleting configuration for server with id: " + id, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int countMessagesInLastMinute(final String serverId) {
|
||||
final var getServerConfig = """
|
||||
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, serverId);
|
||||
final var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
if (resultSet.next()) return resultSet.getInt(1);
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Getting messages for server with id: " + serverId, e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
|
||||
}
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addMessage(final ServerMessage.NewServerMessage serverMessage) {
|
||||
final var getServerConfig = """
|
||||
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, serverMessage.serverId());
|
||||
pstmt.setLong(2, serverMessage.userId());
|
||||
pstmt.setInt(3, serverMessage.tokens());
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No message added for server with id: " + serverMessage.serverId());
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long tokensOfLast30Days(final String serverId) {
|
||||
final var countTokensOfLast30Days = """
|
||||
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
||||
) {
|
||||
pstmt.setString(1, serverId);
|
||||
final var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
if (resultSet.next()) return resultSet.getLong(1);
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
|
||||
}
|
||||
logger.error("No tokens found for server with id: " + serverId);
|
||||
return 0;
|
||||
}
|
||||
}
|
226
db/src/main/java/de/hhhammer/dchat/db/PostgresUserDBService.java
Normal file
226
db/src/main/java/de/hhhammer/dchat/db/PostgresUserDBService.java
Normal file
|
@ -0,0 +1,226 @@
|
|||
package de.hhhammer.dchat.db;
|
||||
|
||||
import de.hhhammer.dchat.db.models.user.UserConfig;
|
||||
import de.hhhammer.dchat.db.models.user.UserMessage;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.*;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
public final class PostgresUserDBService implements UserDBService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(PostgresUserDBService.class);
|
||||
private final DataSource dataSource;
|
||||
|
||||
public PostgresUserDBService(final DataSource dataSource) {
|
||||
this.dataSource = dataSource;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<UserConfig> getConfig(final String userId) {
|
||||
final var getServerConfig = """
|
||||
SELECT * FROM user_configs WHERE user_id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Getting configuration for user with id: " + userId, e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<UserConfig> getConfigBy(final long id) throws DBException {
|
||||
final var getServerConfig = """
|
||||
SELECT * FROM user_configs WHERE id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Getting configuration id: " + id, e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<UserConfig> getAllConfigs() throws DBException {
|
||||
final var getServerConfig = """
|
||||
SELECT * FROM user_configs
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Getting all configurations", e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over all UserConfig ResultSet", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
|
||||
final var getServerConfig = """
|
||||
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newUserConfig.userId());
|
||||
pstmt.setString(2, newUserConfig.systemMessage());
|
||||
pstmt.setInt(3, newUserConfig.contextLength());
|
||||
pstmt.setInt(4, newUserConfig.rateLimit());
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config added for user with id: " + newUserConfig.userId());
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateConfig(final long id, final UserConfig.NewUserConfig newUserConfig) throws DBException {
|
||||
final var getServerConfig = """
|
||||
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newUserConfig.systemMessage());
|
||||
pstmt.setInt(2, newUserConfig.rateLimit());
|
||||
pstmt.setLong(3, newUserConfig.contextLength());
|
||||
pstmt.setString(4, newUserConfig.userId());
|
||||
pstmt.setLong(5, id);
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config update with id: " + id);
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Updating configuration with id: " + id, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteConfig(final long id) throws DBException {
|
||||
final var getServerConfig = """
|
||||
DELETE FROM user_configs WHERE id = ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config deleted for user with id: " + id);
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
throw new DBException("Deleting configuration with id: " + id, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int countMessagesInLastMinute(final String userId) {
|
||||
final var getServerConfig = """
|
||||
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
final var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
if (resultSet.next()) return resultSet.getInt(1);
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Getting messages for user with id: " + userId, e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
|
||||
}
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addMessage(final UserMessage.NewUserMessage newUserMessage) {
|
||||
final var getServerConfig = """
|
||||
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newUserMessage.userId());
|
||||
pstmt.setString(2, newUserMessage.question());
|
||||
pstmt.setString(3, newUserMessage.answer());
|
||||
pstmt.setInt(4, newUserMessage.tokens());
|
||||
final int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No message added for user with id: " + newUserMessage.userId());
|
||||
}
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<UserMessage> getLastMessages(final String userId, final int limit) {
|
||||
final var getLastMessages = """
|
||||
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(getLastMessages)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
pstmt.setInt(2, limit);
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
final Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Fetching last messages for user whit id: " + userId, e);
|
||||
} catch (final ResultSetIteratorException e) {
|
||||
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long tokensOfLast30Days(final String userId) {
|
||||
final var countTokensOfLast30Days = """
|
||||
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
|
||||
""";
|
||||
try (final Connection con = dataSource.getConnection();
|
||||
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
final var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
||||
final ResultSet resultSet = pstmt.executeQuery();
|
||||
if (resultSet.next()) return resultSet.getLong(1);
|
||||
} catch (final SQLException e) {
|
||||
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
|
||||
}
|
||||
logger.error("No tokens found for user with id: " + userId);
|
||||
return 0;
|
||||
}
|
||||
}
|
|
@ -4,11 +4,11 @@ import java.sql.ResultSet;
|
|||
import java.sql.SQLException;
|
||||
import java.util.Iterator;
|
||||
|
||||
public class ResultSetIterator<T> implements Iterator<T> {
|
||||
public final class ResultSetIterator<T> implements Iterator<T> {
|
||||
private final ResultSet resultSet;
|
||||
private final ResultSetTransformer<T> transformer;
|
||||
|
||||
public ResultSetIterator(ResultSet resultSet, ResultSetTransformer<T> transformer) {
|
||||
public ResultSetIterator(final ResultSet resultSet, final ResultSetTransformer<T> transformer) {
|
||||
this.resultSet = resultSet;
|
||||
this.transformer = transformer;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package de.hhhammer.dchat.db;
|
||||
|
||||
public class ResultSetIteratorException extends RuntimeException {
|
||||
public ResultSetIteratorException(Throwable cause) {
|
||||
public final class ResultSetIteratorException extends RuntimeException {
|
||||
public ResultSetIteratorException(final Throwable cause) {
|
||||
super(cause);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,196 +2,26 @@ package de.hhhammer.dchat.db;
|
|||
|
||||
import de.hhhammer.dchat.db.models.server.ServerConfig;
|
||||
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.*;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
public class ServerDBService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ServerDBService.class);
|
||||
private final DataSource dataSource;
|
||||
public interface ServerDBService {
|
||||
Optional<ServerConfig> getConfig(String serverId);
|
||||
|
||||
public ServerDBService(DataSource dataSource) {
|
||||
this.dataSource = dataSource;
|
||||
}
|
||||
List<ServerConfig> getAllConfigs() throws DBException;
|
||||
|
||||
public Optional<ServerConfig> getConfig(String serverId) {
|
||||
var getServerConfig = """
|
||||
SELECT * FROM server_configs WHERE server_id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, serverId);
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (SQLException e) {
|
||||
logger.error("Getting configuration for server with id: " + serverId, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
Optional<ServerConfig> getConfigBy(long id) throws DBException;
|
||||
|
||||
public List<ServerConfig> getAllConfigs() throws DBException {
|
||||
var getAllowedServerSql = """
|
||||
SELECT * FROM server_configs
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
|
||||
) {
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Loading all configs", e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over configs", e);
|
||||
}
|
||||
}
|
||||
void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException;
|
||||
|
||||
public Optional<ServerConfig> getConfigBy(long id) throws DBException {
|
||||
var getServerConfig = """
|
||||
SELECT * FROM server_configs WHERE id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Getting configuration with id: " + id, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
|
||||
}
|
||||
}
|
||||
void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException;
|
||||
|
||||
public void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
||||
var getServerConfig = """
|
||||
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newServerConfig.serverId());
|
||||
pstmt.setString(2, newServerConfig.systemMessage());
|
||||
pstmt.setInt(3, newServerConfig.rateLimit());
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config added for server with id: " + newServerConfig.serverId());
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
||||
}
|
||||
}
|
||||
void deleteConfig(long id) throws DBException;
|
||||
|
||||
public void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
||||
var getServerConfig = """
|
||||
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newServerConfig.systemMessage());
|
||||
pstmt.setInt(2, newServerConfig.rateLimit());
|
||||
pstmt.setString(3, newServerConfig.serverId());
|
||||
pstmt.setLong(4, id);
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config update for server with id: " + newServerConfig.serverId());
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
||||
}
|
||||
}
|
||||
int countMessagesInLastMinute(String serverId);
|
||||
|
||||
public void deleteConfig(long id) throws DBException {
|
||||
var getServerConfig = """
|
||||
DELETE FROM server_configs WHERE id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config deleted for server with id: " + id);
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Deleting configuration for server with id: " + id, e);
|
||||
}
|
||||
}
|
||||
void addMessage(ServerMessage.NewServerMessage serverMessage);
|
||||
|
||||
public int countMessagesInLastMinute(String serverId) {
|
||||
var getServerConfig = """
|
||||
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, serverId);
|
||||
var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
resultSet.next();
|
||||
return resultSet.getInt(1);
|
||||
} catch (SQLException e) {
|
||||
logger.error("Getting messages for server with id: " + serverId, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
|
||||
public void addMessage(ServerMessage.NewServerMessage serverMessage) {
|
||||
var getServerConfig = """
|
||||
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, serverMessage.serverId());
|
||||
pstmt.setLong(2, serverMessage.userId());
|
||||
pstmt.setInt(3, serverMessage.tokens());
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No message added for server with id: " + serverMessage.serverId());
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
public long tokensOfLast30Days(String serverId) {
|
||||
var countTokensOfLast30Days = """
|
||||
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
||||
) {
|
||||
pstmt.setString(1, serverId);
|
||||
var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
resultSet.next();
|
||||
return resultSet.getLong(1);
|
||||
} catch (SQLException e) {
|
||||
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
|
||||
}
|
||||
logger.error("No tokens found for server with id: " + serverId);
|
||||
return 0;
|
||||
}
|
||||
long tokensOfLast30Days(String serverId);
|
||||
}
|
||||
|
|
|
@ -2,219 +2,28 @@ package de.hhhammer.dchat.db;
|
|||
|
||||
import de.hhhammer.dchat.db.models.user.UserConfig;
|
||||
import de.hhhammer.dchat.db.models.user.UserMessage;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.*;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
public class UserDBService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(UserDBService.class);
|
||||
private final DataSource dataSource;
|
||||
public interface UserDBService {
|
||||
Optional<UserConfig> getConfig(String userId);
|
||||
|
||||
public UserDBService(DataSource dataSource) {
|
||||
this.dataSource = dataSource;
|
||||
}
|
||||
Optional<UserConfig> getConfigBy(long id) throws DBException;
|
||||
|
||||
public Optional<UserConfig> getConfig(String userId) {
|
||||
var getServerConfig = """
|
||||
SELECT * FROM user_configs WHERE user_id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (SQLException e) {
|
||||
logger.error("Getting configuration for user with id: " + userId, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
List<UserConfig> getAllConfigs() throws DBException;
|
||||
|
||||
public Optional<UserConfig> getConfigBy(long id) throws DBException {
|
||||
var getServerConfig = """
|
||||
SELECT * FROM user_configs WHERE id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Getting configuration id: " + id, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
|
||||
}
|
||||
}
|
||||
void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException;
|
||||
|
||||
public List<UserConfig> getAllConfigs() throws DBException {
|
||||
var getServerConfig = """
|
||||
SELECT * FROM user_configs
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Getting all configurations", e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
throw new DBException("Iterating over all UserConfig ResultSet", e);
|
||||
}
|
||||
}
|
||||
void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException;
|
||||
|
||||
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
|
||||
var getServerConfig = """
|
||||
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newUserConfig.userId());
|
||||
pstmt.setString(2, newUserConfig.systemMessage());
|
||||
pstmt.setInt(3, newUserConfig.contextLength());
|
||||
pstmt.setInt(4, newUserConfig.rateLimit());
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config added for user with id: " + newUserConfig.userId());
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
|
||||
}
|
||||
}
|
||||
void deleteConfig(long id) throws DBException;
|
||||
|
||||
public void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException {
|
||||
var getServerConfig = """
|
||||
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newUserConfig.systemMessage());
|
||||
pstmt.setInt(2, newUserConfig.rateLimit());
|
||||
pstmt.setLong(3, newUserConfig.contextLength());
|
||||
pstmt.setString(4, newUserConfig.userId());
|
||||
pstmt.setLong(5, id);
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config update with id: " + id);
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Updating configuration with id: " + id, e);
|
||||
}
|
||||
}
|
||||
int countMessagesInLastMinute(String userId);
|
||||
|
||||
public void deleteConfig(long id) throws DBException {
|
||||
var getServerConfig = """
|
||||
DELETE FROM user_configs WHERE id = ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setLong(1, id);
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No config deleted for user with id: " + id);
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new DBException("Deleting configuration with id: " + id, e);
|
||||
}
|
||||
}
|
||||
void addMessage(UserMessage.NewUserMessage newUserMessage);
|
||||
|
||||
public int countMessagesInLastMinute(String userId) {
|
||||
var getServerConfig = """
|
||||
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
resultSet.next();
|
||||
return resultSet.getInt(1);
|
||||
} catch (SQLException e) {
|
||||
logger.error("Getting messages for user with id: " + userId, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
List<UserMessage> getLastMessages(String userId, int limit);
|
||||
|
||||
public void addMessage(UserMessage.NewUserMessage newUserMessage) {
|
||||
var getServerConfig = """
|
||||
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||
) {
|
||||
pstmt.setString(1, newUserMessage.userId());
|
||||
pstmt.setString(2, newUserMessage.question());
|
||||
pstmt.setString(3, newUserMessage.answer());
|
||||
pstmt.setInt(4, newUserMessage.tokens());
|
||||
int affectedRows = pstmt.executeUpdate();
|
||||
if (affectedRows == 0) {
|
||||
logger.error("No message added for user with id: " + newUserMessage.userId());
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
public List<UserMessage> getLastMessages(String userId, int limit) {
|
||||
var getLastMessages = """
|
||||
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(getLastMessages)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
pstmt.setInt(2, limit);
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
|
||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||
} catch (SQLException e) {
|
||||
logger.error("Fetching last messages for user whit id: " + userId, e);
|
||||
} catch (ResultSetIteratorException e) {
|
||||
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
|
||||
public long tokensOfLast30Days(String userId) {
|
||||
var countTokensOfLast30Days = """
|
||||
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
|
||||
""";
|
||||
try (Connection con = dataSource.getConnection();
|
||||
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
||||
) {
|
||||
pstmt.setString(1, userId);
|
||||
var now = Instant.now();
|
||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
||||
ResultSet resultSet = pstmt.executeQuery();
|
||||
resultSet.next();
|
||||
return resultSet.getLong(1);
|
||||
} catch (SQLException e) {
|
||||
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
|
||||
}
|
||||
logger.error("No tokens found for user with id: " + userId);
|
||||
return 0;
|
||||
}
|
||||
long tokensOfLast30Days(String userId);
|
||||
}
|
||||
|
|
|
@ -6,20 +6,20 @@ import org.slf4j.LoggerFactory;
|
|||
/**
|
||||
* Hello world!
|
||||
*/
|
||||
public class App {
|
||||
public final class App {
|
||||
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
||||
private static final String DB_MIGRATION_PATH = "db/schema.sql";
|
||||
|
||||
public static void main(String[] args) {
|
||||
String postgresUser = System.getenv("POSTGRES_USER");
|
||||
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||
String postgresUrl = System.getenv("POSTGRES_URL");
|
||||
public static void main(final String[] args) {
|
||||
final String postgresUser = System.getenv("POSTGRES_USER");
|
||||
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||
final String postgresUrl = System.getenv("POSTGRES_URL");
|
||||
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
||||
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
||||
System.exit(1);
|
||||
}
|
||||
var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
|
||||
var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
|
||||
final var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
|
||||
final var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
|
||||
dbMigrator.run();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package de.hhhammer.dchat.migration;
|
||||
|
||||
public class DBMigrationException extends Exception {
|
||||
public DBMigrationException(Throwable cause) {
|
||||
public final class DBMigrationException extends Exception {
|
||||
public DBMigrationException(final Throwable cause) {
|
||||
super(cause);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,12 +6,12 @@ import org.slf4j.LoggerFactory;
|
|||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
public class DBMigrator implements Runnable {
|
||||
public final class DBMigrator implements Runnable {
|
||||
private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class);
|
||||
private final MigrationExecutor migrationExecutor;
|
||||
private final String resourcePath;
|
||||
|
||||
public DBMigrator(MigrationExecutor migrationExecutor, String resourcePath) {
|
||||
public DBMigrator(final MigrationExecutor migrationExecutor, final String resourcePath) {
|
||||
this.migrationExecutor = migrationExecutor;
|
||||
this.resourcePath = resourcePath;
|
||||
}
|
||||
|
@ -19,8 +19,8 @@ public class DBMigrator implements Runnable {
|
|||
@Override
|
||||
public void run() {
|
||||
logger.info("Starting db migration");
|
||||
ClassLoader classLoader = getClass().getClassLoader();
|
||||
try (InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
|
||||
final ClassLoader classLoader = getClass().getClassLoader();
|
||||
try (final InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
|
||||
if (inputStream == null) {
|
||||
logger.error("Migration file not found: " + resourcePath);
|
||||
throw new RuntimeException("Migration file not found");
|
||||
|
|
|
@ -11,24 +11,24 @@ import java.sql.DriverManager;
|
|||
import java.sql.SQLException;
|
||||
import java.sql.Statement;
|
||||
|
||||
public class MigrationExecutor {
|
||||
public final class MigrationExecutor {
|
||||
private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class);
|
||||
private final String jdbcConnectionString;
|
||||
private final String username;
|
||||
private final String password;
|
||||
|
||||
public MigrationExecutor(String jdbcConnectionString, String username, String password) {
|
||||
public MigrationExecutor(final String jdbcConnectionString, final String username, final String password) {
|
||||
this.jdbcConnectionString = jdbcConnectionString;
|
||||
this.username = username;
|
||||
this.password = password;
|
||||
}
|
||||
|
||||
public void migrate(InputStream input) throws DBMigrationException {
|
||||
try (Connection con = DriverManager
|
||||
public void migrate(final InputStream input) throws DBMigrationException {
|
||||
try (final Connection con = DriverManager
|
||||
.getConnection(this.jdbcConnectionString, this.username, this.password);
|
||||
Statement stmp = con.createStatement();
|
||||
final Statement stmp = con.createStatement();
|
||||
) {
|
||||
String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
|
||||
final String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
|
||||
stmp.execute(content);
|
||||
} catch (SQLException | IOException e) {
|
||||
throw new DBMigrationException(e);
|
||||
|
|
|
@ -12,10 +12,6 @@
|
|||
<version>1.0-SNAPSHOT</version>
|
||||
<name>web</name>
|
||||
|
||||
<properties>
|
||||
<node.version>v18.16.0</node.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>de.hhhammer.dchat</groupId>
|
||||
|
|
|
@ -2,41 +2,41 @@ package de.hhhammer.dchat.web;
|
|||
|
||||
import com.zaxxer.hikari.HikariConfig;
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
import de.hhhammer.dchat.db.ServerDBService;
|
||||
import de.hhhammer.dchat.db.UserDBService;
|
||||
import de.hhhammer.dchat.db.PostgresServerDBService;
|
||||
import de.hhhammer.dchat.db.PostgresUserDBService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Hello world!
|
||||
*/
|
||||
public class App {
|
||||
public final class App {
|
||||
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
||||
|
||||
public static void main(String[] args) {
|
||||
String postgresUser = System.getenv("POSTGRES_USER");
|
||||
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||
String postgresUrl = System.getenv("POSTGRES_URL");
|
||||
public static void main(final String[] args) {
|
||||
final String postgresUser = System.getenv("POSTGRES_USER");
|
||||
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||
final String postgresUrl = System.getenv("POSTGRES_URL");
|
||||
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
||||
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
|
||||
int apiPort = Integer.parseInt(apiPortStr);
|
||||
boolean debug = "true".equals(System.getenv("API_DEBUG"));
|
||||
final String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
|
||||
final int apiPort = Integer.parseInt(apiPortStr);
|
||||
final boolean debug = "true".equals(System.getenv("API_DEBUG"));
|
||||
|
||||
var config = new HikariConfig();
|
||||
final var config = new HikariConfig();
|
||||
config.setJdbcUrl(postgresUrl);
|
||||
config.setUsername(postgresUser);
|
||||
config.setPassword(postgresPassword);
|
||||
|
||||
try (var ds = new HikariDataSource(config)) {
|
||||
var serverDBService = new ServerDBService(ds);
|
||||
var userDBService = new UserDBService(ds);
|
||||
var appConfig = new AppConfig(apiPort, debug);
|
||||
try (final var ds = new HikariDataSource(config)) {
|
||||
final var serverDBService = new PostgresServerDBService(ds);
|
||||
final var userDBService = new PostgresUserDBService(ds);
|
||||
final var appConfig = new AppConfig(apiPort, debug);
|
||||
|
||||
var webApi = new WebAPI(serverDBService, userDBService, appConfig);
|
||||
final var webApi = new WebAPI(serverDBService, userDBService, appConfig);
|
||||
webApi.run();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import de.hhhammer.dchat.web.server.ConfigCrudHandler;
|
|||
import de.hhhammer.dchat.web.user.ConfigUserCrudHandler;
|
||||
import io.javalin.Javalin;
|
||||
import io.javalin.http.HttpStatus;
|
||||
import io.javalin.http.staticfiles.Location;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -16,13 +15,13 @@ import java.util.concurrent.ExecutionException;
|
|||
import static io.javalin.apibuilder.ApiBuilder.crud;
|
||||
import static io.javalin.apibuilder.ApiBuilder.path;
|
||||
|
||||
public class WebAPI implements Runnable {
|
||||
public final class WebAPI implements Runnable {
|
||||
private static final Logger logger = LoggerFactory.getLogger(WebAPI.class);
|
||||
private final ServerDBService serverDBService;
|
||||
private final UserDBService userDBService;
|
||||
private final AppConfig appConfig;
|
||||
|
||||
public WebAPI(ServerDBService serverDBService, UserDBService userDBService, AppConfig appConfig) {
|
||||
public WebAPI(final ServerDBService serverDBService, final UserDBService userDBService, final AppConfig appConfig) {
|
||||
this.serverDBService = serverDBService;
|
||||
this.userDBService = userDBService;
|
||||
this.appConfig = appConfig;
|
||||
|
@ -31,12 +30,12 @@ public class WebAPI implements Runnable {
|
|||
@Override
|
||||
public void run() {
|
||||
logger.info("Starting web application");
|
||||
var app = Javalin.create(config -> {
|
||||
final Javalin app = Javalin.create(config -> {
|
||||
if (appConfig.debug()) config.plugins.enableDevLogging();
|
||||
config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method
|
||||
config.http.defaultContentType = "application/json";
|
||||
});
|
||||
var waitForShutdown = new CompletableFuture<Void>();
|
||||
final var waitForShutdown = new CompletableFuture<Void>();
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
logger.info("Shutting down web application");
|
||||
app.stop();
|
||||
|
|
|
@ -10,18 +10,21 @@ import org.jetbrains.annotations.NotNull;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class ConfigCrudHandler implements CrudHandler {
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public final class ConfigCrudHandler implements CrudHandler {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class);
|
||||
|
||||
private final ServerDBService serverDBService;
|
||||
|
||||
public ConfigCrudHandler(ServerDBService serverDBService) {
|
||||
public ConfigCrudHandler(final ServerDBService serverDBService) {
|
||||
this.serverDBService = serverDBService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void create(@NotNull Context context) {
|
||||
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
||||
public void create(@NotNull final Context context) {
|
||||
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
||||
try {
|
||||
this.serverDBService.addConfig(body);
|
||||
} catch (DBException e) {
|
||||
|
@ -32,9 +35,10 @@ public class ConfigCrudHandler implements CrudHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void delete(@NotNull Context context, @NotNull String s) {
|
||||
public void delete(@NotNull final Context context, @NotNull final String s) {
|
||||
final var id = Long.parseLong(s);
|
||||
try {
|
||||
this.serverDBService.deleteConfig(Long.parseLong(s));
|
||||
this.serverDBService.deleteConfig(id);
|
||||
context.status(HttpStatus.NO_CONTENT);
|
||||
} catch (DBException e) {
|
||||
logger.error("Deleting configuration with id: " + s, e);
|
||||
|
@ -43,9 +47,9 @@ public class ConfigCrudHandler implements CrudHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void getAll(@NotNull Context context) {
|
||||
public void getAll(@NotNull final Context context) {
|
||||
try {
|
||||
var allowedServers = this.serverDBService.getAllConfigs();
|
||||
final List<ServerConfig> allowedServers = this.serverDBService.getAllConfigs();
|
||||
context.json(allowedServers);
|
||||
} catch (DBException e) {
|
||||
logger.error("Getting all server configs", e);
|
||||
|
@ -54,10 +58,10 @@ public class ConfigCrudHandler implements CrudHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void getOne(@NotNull Context context, @NotNull String s) {
|
||||
var id = Long.parseLong(s);
|
||||
public void getOne(@NotNull final Context context, @NotNull final String s) {
|
||||
final var id = Long.parseLong(s);
|
||||
try {
|
||||
var server = this.serverDBService.getConfigBy(id);
|
||||
final Optional<ServerConfig> server = this.serverDBService.getConfigBy(id);
|
||||
if (server.isEmpty()) {
|
||||
context.status(HttpStatus.NOT_FOUND);
|
||||
return;
|
||||
|
@ -70,9 +74,9 @@ public class ConfigCrudHandler implements CrudHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void update(@NotNull Context context, @NotNull String idString) {
|
||||
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
||||
var id = Long.parseLong(idString);
|
||||
public void update(@NotNull final Context context, @NotNull final String idString) {
|
||||
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
||||
final var id = Long.parseLong(idString);
|
||||
|
||||
try {
|
||||
this.serverDBService.updateConfig(id, body);
|
||||
|
|
|
@ -10,18 +10,21 @@ import org.jetbrains.annotations.NotNull;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class ConfigUserCrudHandler implements CrudHandler {
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public final class ConfigUserCrudHandler implements CrudHandler {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class);
|
||||
|
||||
private final UserDBService userDBService;
|
||||
|
||||
public ConfigUserCrudHandler(UserDBService userDBService) {
|
||||
public ConfigUserCrudHandler(final UserDBService userDBService) {
|
||||
this.userDBService = userDBService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void create(@NotNull Context context) {
|
||||
var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
||||
public void create(@NotNull final Context context) {
|
||||
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
||||
try {
|
||||
this.userDBService.addConfig(body);
|
||||
} catch (DBException e) {
|
||||
|
@ -33,9 +36,10 @@ public class ConfigUserCrudHandler implements CrudHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void delete(@NotNull Context context, @NotNull String s) {
|
||||
public void delete(@NotNull final Context context, @NotNull final String s) {
|
||||
final var id = Long.parseLong(s);
|
||||
try {
|
||||
this.userDBService.deleteConfig(Long.parseLong(s));
|
||||
this.userDBService.deleteConfig(id);
|
||||
context.status(HttpStatus.NO_CONTENT);
|
||||
} catch (DBException e) {
|
||||
logger.error("Deleting configuration with id: " + s, e);
|
||||
|
@ -44,9 +48,9 @@ public class ConfigUserCrudHandler implements CrudHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void getAll(@NotNull Context context) {
|
||||
public void getAll(@NotNull final Context context) {
|
||||
try {
|
||||
var allowedServers = this.userDBService.getAllConfigs();
|
||||
final List<UserConfig> allowedServers = this.userDBService.getAllConfigs();
|
||||
context.json(allowedServers);
|
||||
} catch (DBException e) {
|
||||
logger.error("Getting all user configs", e);
|
||||
|
@ -55,10 +59,10 @@ public class ConfigUserCrudHandler implements CrudHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void getOne(@NotNull Context context, @NotNull String s) {
|
||||
var id = Long.parseLong(s);
|
||||
public void getOne(@NotNull final Context context, @NotNull final String s) {
|
||||
final var id = Long.parseLong(s);
|
||||
try {
|
||||
var server = this.userDBService.getConfigBy(id);
|
||||
final Optional<UserConfig> server = this.userDBService.getConfigBy(id);
|
||||
if (server.isEmpty()) {
|
||||
context.status(HttpStatus.NOT_FOUND);
|
||||
return;
|
||||
|
@ -72,8 +76,8 @@ public class ConfigUserCrudHandler implements CrudHandler {
|
|||
|
||||
@Override
|
||||
public void update(@NotNull Context context, @NotNull String idString) {
|
||||
var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
||||
var id = Long.parseLong(idString);
|
||||
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
||||
final var id = Long.parseLong(idString);
|
||||
|
||||
try {
|
||||
this.userDBService.updateConfig(id, body);
|
||||
|
|
Loading…
Reference in a new issue