Compare commits

..

No commits in common. "168a5d6c818327cbf1d08964acaee4d5646cf9c6" and "5c67a47806c298746576d7b9f3e9fdadd39c335e" have entirely different histories.

24 changed files with 562 additions and 642 deletions

View file

@ -4,47 +4,47 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.db.PostgresServerDBService; import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.PostgresUserDBService; import de.hhhammer.dchat.db.UserDBService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.net.http.HttpClient; import java.net.http.HttpClient;
public final class App { public class App {
private static final Logger logger = LoggerFactory.getLogger(App.class); private static final Logger logger = LoggerFactory.getLogger(App.class);
public static void main(final String[] args) { public static void main(String[] args) {
final String discordApiKey = System.getenv("DISCORD_API_KEY"); String discordApiKey = System.getenv("DISCORD_API_KEY");
if (discordApiKey == null) { if (discordApiKey == null) {
logger.error("Missing environment variables: DISCORD_API_KEY"); logger.error("Missing environment variables: DISCORD_API_KEY");
System.exit(1); System.exit(1);
} }
final String openaiApiKey = System.getenv("OPENAI_API_KEY"); String openaiApiKey = System.getenv("OPENAI_API_KEY");
if (openaiApiKey == null) { if (openaiApiKey == null) {
logger.error("Missing environment variables: OPENAI_API_KEY"); logger.error("Missing environment variables: OPENAI_API_KEY");
System.exit(1); System.exit(1);
} }
final String postgresUser = System.getenv("POSTGRES_USER"); String postgresUser = System.getenv("POSTGRES_USER");
final String postgresPassword = System.getenv("POSTGRES_PASSWORD"); String postgresPassword = System.getenv("POSTGRES_PASSWORD");
final String postgresUrl = System.getenv("POSTGRES_URL"); String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) { if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL"); logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1); System.exit(1);
} }
final var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper()); var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
final var config = new HikariConfig(); var config = new HikariConfig();
config.setJdbcUrl(postgresUrl); config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser); config.setUsername(postgresUser);
config.setPassword(postgresPassword); config.setPassword(postgresPassword);
try (var ds = new HikariDataSource(config)) { try (var ds = new HikariDataSource(config)) {
final var serverDBService = new PostgresServerDBService(ds); var serverDBService = new ServerDBService(ds);
final var userDBService = new PostgresUserDBService(ds); var userDBService = new UserDBService(ds);
final var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey); var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
discordBot.run(); discordBot.run();
} }
} }

View file

@ -9,7 +9,6 @@ import de.hhhammer.dchat.db.UserDBService;
import org.javacord.api.DiscordApi; import org.javacord.api.DiscordApi;
import org.javacord.api.DiscordApiBuilder; import org.javacord.api.DiscordApiBuilder;
import org.javacord.api.interaction.SlashCommand; import org.javacord.api.interaction.SlashCommand;
import org.javacord.api.interaction.SlashCommandInteraction;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -17,7 +16,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public final class DiscordBot implements Runnable { public class DiscordBot implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class); private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
@ -25,7 +24,7 @@ public final class DiscordBot implements Runnable {
private final ChatGPTService chatGPTService; private final ChatGPTService chatGPTService;
private final String discordApiKey; private final String discordApiKey;
public DiscordBot(final ServerDBService serverDBService, final UserDBService userDBService, final ChatGPTService chatGPTService, final String discordApiKey) { public DiscordBot(ServerDBService serverDBService, UserDBService userDBService, ChatGPTService chatGPTService, String discordApiKey) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
this.userDBService = userDBService; this.userDBService = userDBService;
this.chatGPTService = chatGPTService; this.chatGPTService = chatGPTService;
@ -35,29 +34,29 @@ public final class DiscordBot implements Runnable {
@Override @Override
public void run() { public void run() {
logger.info("Starting Discord application"); logger.info("Starting Discord application");
final DiscordApi discordApi = new DiscordApiBuilder() DiscordApi discordApi = new DiscordApiBuilder()
.setToken(discordApiKey) .setToken(discordApiKey)
.login() .login()
.join(); .join();
discordApi.setMessageCacheSize(10, 60*60); discordApi.setMessageCacheSize(10, 60*60);
final var future = new CompletableFuture<Void>(); var future = new CompletableFuture<Void>();
Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> { Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> {
logger.info("Shutting down Discord application"); logger.info("Shutting down Discord application");
discordApi.disconnect().thenAccept(future::complete); discordApi.disconnect().thenAccept(future::complete);
})); }));
final SlashCommand token = SlashCommand.with("tokens", "Check how many tokens where spend on this server") var token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
.createGlobal(discordApi) .createGlobal(discordApi)
.join(); .join();
discordApi.addSlashCommandCreateListener(event -> { discordApi.addSlashCommandCreateListener(event -> {
logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName()); logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName());
final SlashCommandInteraction command = event.getSlashCommandInteraction(); var command = event.getSlashCommandInteraction();
if (token.getFullCommandNames().contains(command.getFullCommandName())) { if (token.getFullCommandNames().contains(command.getFullCommandName())) {
event.getInteraction() event.getInteraction()
.respondLater() .respondLater()
.orTimeout(30, TimeUnit.SECONDS) .orTimeout(30, TimeUnit.SECONDS)
.thenAccept((interactionOriginalResponseUpdater) -> { .thenAccept((interactionOriginalResponseUpdater) -> {
final long tokens = event.getInteraction().getServer().isPresent() ? var tokens = event.getInteraction().getServer().isPresent() ?
this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) : this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) :
this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId())); this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId()));
interactionOriginalResponseUpdater.setContent("" + tokens).update(); interactionOriginalResponseUpdater.setContent("" + tokens).update();

View file

@ -9,17 +9,17 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
public final class MessageCreateHandler implements MessageCreateListener { public class MessageCreateHandler implements MessageCreateListener {
private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class); private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class);
private final MessageHandler messageHandler; private final MessageHandler messageHandler;
public MessageCreateHandler(final MessageHandler messageHandler) { public MessageCreateHandler(MessageHandler messageHandler) {
this.messageHandler = messageHandler; this.messageHandler = messageHandler;
} }
@Override @Override
public void onMessageCreate(final MessageCreateEvent event) { public void onMessageCreate(MessageCreateEvent event) {
Thread.ofVirtual().start(() -> { Thread.ofVirtual().start(() -> {
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) { if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
return; return;
@ -36,15 +36,15 @@ public final class MessageCreateHandler implements MessageCreateListener {
} }
try { try {
this.messageHandler.handle(event); this.messageHandler.handle(event);
} catch (final ResponseException | IOException | InterruptedException e) { } catch (ResponseException | IOException | InterruptedException e) {
logger.error("Reading a message from the listener", e); logger.error("Reading a message from the listener", e);
event.getMessage().reply("Sorry but something went wrong :("); event.getMessage().reply("Sorry but something went wrong :(");
} }
}); });
} }
private boolean isNormalOrReplyMessageType(final MessageCreateEvent event) { private boolean isNormalOrReplyMessageType(MessageCreateEvent event) {
final MessageType type = event.getMessage().getType(); MessageType type = event.getMessage().getType();
return type == MessageType.NORMAL || type == MessageType.REPLY; return type == MessageType.NORMAL || type == MessageType.REPLY;
} }
} }

View file

@ -4,10 +4,7 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction; import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException; import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import de.hhhammer.dchat.db.ServerDBService; import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage; import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.javacord.api.entity.DiscordEntity; import org.javacord.api.entity.DiscordEntity;
import org.javacord.api.entity.message.Message; import org.javacord.api.entity.message.Message;
@ -20,44 +17,43 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional;
public final class ServerMessageHandler implements MessageHandler { public class ServerMessageHandler implements MessageHandler {
private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class); private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
private final ChatGPTService chatGPTService; private final ChatGPTService chatGPTService;
public ServerMessageHandler(final ServerDBService serverDBService, final ChatGPTService chatGPTService) { public ServerMessageHandler(ServerDBService serverDBService, ChatGPTService chatGPTService) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
this.chatGPTService = chatGPTService; this.chatGPTService = chatGPTService;
} }
@Override @Override
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException { public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
final String content = extractContent(event); String content = extractContent(event);
final long serverId = event.getServer().get().getId(); var serverId = event.getServer().get().getId();
final String systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage(); var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
final List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of(); List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage); var request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
final ChatGPTResponse response = this.chatGPTService.submit(request); var response = this.chatGPTService.submit(request);
if (response.choices().isEmpty()) { if (response.choices().isEmpty()) {
event.getMessage().reply("No response available"); event.getMessage().reply("No response available");
return; return;
} }
final String answer = response.choices().get(0).message().content(); var answer = response.choices().get(0).message().content();
logServerMessage(event, response.usage().totalTokens()); logServerMessage(event, response.usage().totalTokens());
event.getMessage().reply(answer); event.getMessage().reply(answer);
} }
@Override @Override
public boolean isAllowed(final MessageCreateEvent event) { public boolean isAllowed(MessageCreateEvent event) {
if (event.getServer().isEmpty()) { if (event.getServer().isEmpty()) {
return false; return false;
} }
final long serverId = event.getServer().get().getId(); var serverId = event.getServer().get().getId();
final Optional<ServerConfig> config = this.serverDBService.getConfig(String.valueOf(serverId)); var config = this.serverDBService.getConfig(String.valueOf(serverId));
if (config.isEmpty()) { if (config.isEmpty()) {
logger.debug("Not allowed with id: " + serverId); logger.debug("Not allowed with id: " + serverId);
return false; return false;
@ -66,39 +62,39 @@ public final class ServerMessageHandler implements MessageHandler {
} }
@Override @Override
public boolean exceedsRate(final MessageCreateEvent event) { public boolean exceedsRate(MessageCreateEvent event) {
final String serverId = String.valueOf(event.getServer().get().getId()); var serverId = String.valueOf(event.getServer().get().getId());
final Optional<ServerConfig> config = this.serverDBService.getConfig(serverId); var config = this.serverDBService.getConfig(serverId);
if (config.isEmpty()) { if (config.isEmpty()) {
logger.error("Missing configuration for server with id: " + serverId); logger.error("Missing configuration for server with id: " + serverId);
return true; return true;
} }
final int rateLimit = config.get().rateLimit(); var rateLimit = config.get().rateLimit();
final int countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId); var countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
return countMessagesInLastMinute >= rateLimit; return countMessagesInLastMinute >= rateLimit;
} }
@Override @Override
public boolean canHandle(final MessageCreateEvent event) { public boolean canHandle(MessageCreateEvent event) {
return event.isServerMessage(); return event.isServerMessage();
} }
private void logServerMessage(final MessageCreateEvent event, final int tokens) { private void logServerMessage(MessageCreateEvent event, int tokens) {
final long serverId = event.getServer().map(DiscordEntity::getId).get(); var serverId = event.getServer().map(DiscordEntity::getId).get();
final long userId = event.getMessageAuthor().getId(); var userId = event.getMessageAuthor().getId();
final var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens); var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
this.serverDBService.addMessage(serverMessage); this.serverDBService.addMessage(serverMessage);
} }
private String extractContent(final MessageCreateEvent event) { private String extractContent(MessageCreateEvent event) {
final long ownId = event.getApi().getYourself().getId(); long ownId = event.getApi().getYourself().getId();
return event.getMessageContent().replaceFirst("<" + ownId + "> ", ""); return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
} }
@NotNull @NotNull
private List<ReplyInteraction> getContextMessages(final MessageCreateEvent event) { private List<ReplyInteraction> getContextMessages(MessageCreateEvent event) {
return event.getMessage() return event.getMessage()
.getMessageReference() .getMessageReference()
.map(MessageReference::getMessage) .map(MessageReference::getMessage)

View file

@ -4,10 +4,7 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction; import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException; import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import de.hhhammer.dchat.db.UserDBService; import de.hhhammer.dchat.db.UserDBService;
import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage; import de.hhhammer.dchat.db.models.user.UserMessage;
import org.javacord.api.event.message.MessageCreateEvent; import org.javacord.api.event.message.MessageCreateEvent;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -15,47 +12,46 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional;
public final class UserMessageHandler implements MessageHandler { public class UserMessageHandler implements MessageHandler {
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class); private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
private final UserDBService userDBService; private final UserDBService userDBService;
private final ChatGPTService chatGPTService; private final ChatGPTService chatGPTService;
public UserMessageHandler(final UserDBService userDBService, final ChatGPTService chatGPTService) { public UserMessageHandler(UserDBService userDBService, ChatGPTService chatGPTService) {
this.userDBService = userDBService; this.userDBService = userDBService;
this.chatGPTService = chatGPTService; this.chatGPTService = chatGPTService;
} }
@Override @Override
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException { public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
final String content = event.getReadableMessageContent(); String content = event.getReadableMessageContent();
final String userId = String.valueOf(event.getMessageAuthor().getId()); var userId = String.valueOf(event.getMessageAuthor().getId());
final UserConfig config = this.userDBService.getConfig(userId).get(); var config = this.userDBService.getConfig(userId).get();
final String systemMessage = config.systemMessage(); var systemMessage = config.systemMessage();
final List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength()) List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
.stream() .stream()
.map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer())) .map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
.toList(); .toList();
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage); var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
final ChatGPTResponse response = this.chatGPTService.submit(request); var response = this.chatGPTService.submit(request);
if (response.choices().isEmpty()) { if (response.choices().isEmpty()) {
event.getMessage().reply("No response available"); event.getMessage().reply("No response available");
return; return;
} }
final String answer = response.choices().get(0).message().content(); var answer = response.choices().get(0).message().content();
logUserMessage(event, content, answer, response.usage().totalTokens()); logUserMessage(event, content, answer, response.usage().totalTokens());
event.getChannel().sendMessage(answer); event.getChannel().sendMessage(answer);
} }
@Override @Override
public boolean isAllowed(final MessageCreateEvent event) { public boolean isAllowed(MessageCreateEvent event) {
if (event.getServer().isPresent()) { if (event.getServer().isPresent()) {
return false; return false;
} }
final long userId = event.getMessageAuthor().getId(); var userId = event.getMessageAuthor().getId();
final Optional<UserConfig> config = this.userDBService.getConfig(String.valueOf(userId)); var config = this.userDBService.getConfig(String.valueOf(userId));
if (config.isEmpty()) { if (config.isEmpty()) {
logger.debug("Not allowed with id: " + userId); logger.debug("Not allowed with id: " + userId);
return false; return false;
@ -64,28 +60,28 @@ public final class UserMessageHandler implements MessageHandler {
} }
@Override @Override
public boolean exceedsRate(final MessageCreateEvent event) { public boolean exceedsRate(MessageCreateEvent event) {
final String userId = String.valueOf(event.getMessageAuthor().getId()); var userId = String.valueOf(event.getMessageAuthor().getId());
final Optional<UserConfig> config = this.userDBService.getConfig(userId); var config = this.userDBService.getConfig(userId);
if (config.isEmpty()) { if (config.isEmpty()) {
logger.error("Missing configuration for userId with id: " + userId); logger.error("Missing configuration for userId with id: " + userId);
return true; return true;
} }
final int rateLimit = config.get().rateLimit(); var rateLimit = config.get().rateLimit();
final int countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId); var countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
return countMessagesInLastMinute >= rateLimit; return countMessagesInLastMinute >= rateLimit;
} }
@Override @Override
public boolean canHandle(final MessageCreateEvent event) { public boolean canHandle(MessageCreateEvent event) {
return event.isPrivateMessage(); return event.isPrivateMessage();
} }
private void logUserMessage(final MessageCreateEvent event, final String question, final String answer, final int tokens) { private void logUserMessage(MessageCreateEvent event, String question, String answer, int tokens) {
final long userId = event.getMessageAuthor().getId(); var userId = event.getMessageAuthor().getId();
final UserMessage.NewUserMessage userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens); var userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
this.userDBService.addMessage(userMessage); this.userDBService.addMessage(userMessage);
} }
} }

View file

@ -8,19 +8,29 @@ import org.jetbrains.annotations.NotNull;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
public final class ChatGPTRequestBuilder { public class ChatGPTRequestBuilder {
private static final String smallContextModel = "gpt-3.5-turbo"; private static final String smallContextModel = "gpt-3.5-turbo";
private static final String bigContextModel = "gpt-3.5-turbo-16k"; private static final String bigContextModel = "gpt-3.5-turbo-16k";
public ChatGPTRequestBuilder() { public ChatGPTRequestBuilder() {
} }
public ChatGPTRequest contextRequest(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
return new ChatGPTRequest(
contextModel,
messages,
0.7f
);
}
@NotNull @NotNull
private static List<ChatGPTRequest.Message> getMessages(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) { private static List<ChatGPTRequest.Message> getMessages(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
final ChatGPTRequest.Message systemMsg = new ChatGPTRequest.Message("system", systemMessage); var systemMsg = new ChatGPTRequest.Message("system", systemMessage);
final List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages); List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
final ChatGPTRequest.Message userMessage = new ChatGPTRequest.Message("user", message); var userMessage = new ChatGPTRequest.Message("user", message);
final List<ChatGPTRequest.Message> messages = new ArrayList<>(); List<ChatGPTRequest.Message> messages = new ArrayList<>();
messages.add(systemMsg); messages.add(systemMsg);
messages.addAll(contextMsgs); messages.addAll(contextMsgs);
messages.add(userMessage); messages.add(userMessage);
@ -28,7 +38,7 @@ public final class ChatGPTRequestBuilder {
} }
@NotNull @NotNull
private static List<ChatGPTRequest.Message> getContextMessages(final List<? extends MessageContext> contextMessages) { private static List<ChatGPTRequest.Message> getContextMessages(List<? extends MessageContext> contextMessages) {
return contextMessages.stream() return contextMessages.stream()
.map(ChatGPTRequestBuilder::mapContextMessages) .map(ChatGPTRequestBuilder::mapContextMessages)
.flatMap(List::stream) .flatMap(List::stream)
@ -36,7 +46,7 @@ public final class ChatGPTRequestBuilder {
} }
@NotNull @NotNull
private static List<ChatGPTRequest.Message> mapContextMessages(final MessageContext contextMessage) { private static List<ChatGPTRequest.Message> mapContextMessages(MessageContext contextMessage) {
return switch (contextMessage) { return switch (contextMessage) {
case PreviousInteraction previousInteractions -> List.of( case PreviousInteraction previousInteractions -> List.of(
new ChatGPTRequest.Message("user", previousInteractions.question()), new ChatGPTRequest.Message("user", previousInteractions.question()),
@ -46,14 +56,4 @@ public final class ChatGPTRequestBuilder {
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer())); List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
}; };
} }
public ChatGPTRequest contextRequest(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
final List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
final String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
return new ChatGPTRequest(
contextModel,
messages,
0.7f
);
}
} }

View file

@ -5,35 +5,34 @@ import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse; import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.net.http.HttpClient; import java.net.http.HttpClient;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.time.Duration; import java.time.Duration;
public final class ChatGPTService { public class ChatGPTService {
private static final String url = "https://api.openai.com/v1/chat/completions"; private static final String url = "https://api.openai.com/v1/chat/completions";
private final String apiKey; private final String apiKey;
private final HttpClient httpClient; private final HttpClient httpClient;
private final ObjectMapper mapper; private final ObjectMapper mapper;
public ChatGPTService(final String apiKey, final HttpClient httpClient, final ObjectMapper mapper) { public ChatGPTService(String apiKey, HttpClient httpClient, ObjectMapper mapper) {
this.apiKey = apiKey; this.apiKey = apiKey;
this.httpClient = httpClient; this.httpClient = httpClient;
this.mapper = mapper; this.mapper = mapper;
} }
public ChatGPTResponse submit(final ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException { public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
final byte[] data = mapper.writeValueAsBytes(chatGPTRequest); var data = mapper.writeValueAsBytes(chatGPTRequest);
final HttpRequest request = HttpRequest.newBuilder(URI.create(url)) var request = HttpRequest.newBuilder(URI.create(url))
.POST(HttpRequest.BodyPublishers.ofByteArray(data)) .POST(HttpRequest.BodyPublishers.ofByteArray(data))
.setHeader("Content-Type", "application/json") .setHeader("Content-Type", "application/json")
.setHeader("Authorization", "Bearer " + this.apiKey) .setHeader("Authorization", "Bearer " + this.apiKey)
.timeout(Duration.ofMinutes(5)) .timeout(Duration.ofMinutes(5))
.build(); .build();
final HttpResponse<InputStream> responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); var responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
if (responseStream.statusCode() != 200) { if (responseStream.statusCode() != 200) {
throw new ResponseException("Response status code was not 200: " + responseStream.statusCode()); throw new ResponseException("Response status code was not 200: " + responseStream.statusCode());
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.bot.openai; package de.hhhammer.dchat.bot.openai;
public final class ResponseException extends Exception { public class ResponseException extends Exception {
public ResponseException(final String message) { public ResponseException(String message) {
super(message); super(message);
} }
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.db; package de.hhhammer.dchat.db;
public final class DBException extends Exception { public class DBException extends Exception {
public DBException(final String message, final Throwable cause) { public DBException(String message, Throwable cause) {
super(message, cause); super(message, cause);
} }
} }

View file

@ -1,202 +0,0 @@
package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public final class PostgresServerDBService implements ServerDBService {
private static final Logger logger = LoggerFactory.getLogger(PostgresServerDBService.class);
private final DataSource dataSource;
public PostgresServerDBService(final DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public Optional<ServerConfig> getConfig(final String serverId) {
final var getServerConfig = """
SELECT * FROM server_configs WHERE server_id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
logger.error("Getting configuration for server with id: " + serverId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
}
return Optional.empty();
}
@Override
public List<ServerConfig> getAllConfigs() throws DBException {
final var getAllowedServerSql = """
SELECT * FROM server_configs
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
) {
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
throw new DBException("Loading all configs", e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over configs", e);
}
}
@Override
public Optional<ServerConfig> getConfigBy(final long id) throws DBException {
final var getServerConfig = """
SELECT * FROM server_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration with id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
}
}
@Override
public void addConfig(final ServerConfig.NewServerConfig newServerConfig) throws DBException {
final var getServerConfig = """
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.serverId());
pstmt.setString(2, newServerConfig.systemMessage());
pstmt.setInt(3, newServerConfig.rateLimit());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for server with id: " + newServerConfig.serverId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
@Override
public void updateConfig(final long id, final ServerConfig.NewServerConfig newServerConfig) throws DBException {
final var getServerConfig = """
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.systemMessage());
pstmt.setInt(2, newServerConfig.rateLimit());
pstmt.setString(3, newServerConfig.serverId());
pstmt.setLong(4, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update for server with id: " + newServerConfig.serverId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
@Override
public void deleteConfig(final long id) throws DBException {
final var getServerConfig = """
DELETE FROM server_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for server with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Deleting configuration for server with id: " + id, e);
}
}
@Override
public int countMessagesInLastMinute(final String serverId) {
final var getServerConfig = """
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getInt(1);
} catch (final SQLException e) {
logger.error("Getting messages for server with id: " + serverId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
}
return Integer.MAX_VALUE;
}
@Override
public void addMessage(final ServerMessage.NewServerMessage serverMessage) {
final var getServerConfig = """
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverMessage.serverId());
pstmt.setLong(2, serverMessage.userId());
pstmt.setInt(3, serverMessage.tokens());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for server with id: " + serverMessage.serverId());
}
} catch (final SQLException e) {
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
}
}
@Override
public long tokensOfLast30Days(final String serverId) {
final var countTokensOfLast30Days = """
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, serverId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getLong(1);
} catch (final SQLException e) {
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
}
logger.error("No tokens found for server with id: " + serverId);
return 0;
}
}

View file

@ -1,226 +0,0 @@
package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public final class PostgresUserDBService implements UserDBService {
private static final Logger logger = LoggerFactory.getLogger(PostgresUserDBService.class);
private final DataSource dataSource;
public PostgresUserDBService(final DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public Optional<UserConfig> getConfig(final String userId) {
final var getServerConfig = """
SELECT * FROM user_configs WHERE user_id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
logger.error("Getting configuration for user with id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
}
return Optional.empty();
}
@Override
public Optional<UserConfig> getConfigBy(final long id) throws DBException {
final var getServerConfig = """
SELECT * FROM user_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
throw new DBException("Getting configuration id: " + id, e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
}
}
@Override
public List<UserConfig> getAllConfigs() throws DBException {
final var getServerConfig = """
SELECT * FROM user_configs
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
throw new DBException("Getting all configurations", e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over all UserConfig ResultSet", e);
}
}
@Override
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
final var getServerConfig = """
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.userId());
pstmt.setString(2, newUserConfig.systemMessage());
pstmt.setInt(3, newUserConfig.contextLength());
pstmt.setInt(4, newUserConfig.rateLimit());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for user with id: " + newUserConfig.userId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
}
}
@Override
public void updateConfig(final long id, final UserConfig.NewUserConfig newUserConfig) throws DBException {
final var getServerConfig = """
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.systemMessage());
pstmt.setInt(2, newUserConfig.rateLimit());
pstmt.setLong(3, newUserConfig.contextLength());
pstmt.setString(4, newUserConfig.userId());
pstmt.setLong(5, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Updating configuration with id: " + id, e);
}
}
@Override
public void deleteConfig(final long id) throws DBException {
final var getServerConfig = """
DELETE FROM user_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for user with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Deleting configuration with id: " + id, e);
}
}
@Override
public int countMessagesInLastMinute(final String userId) {
final var getServerConfig = """
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getInt(1);
} catch (final SQLException e) {
logger.error("Getting messages for user with id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
}
return Integer.MAX_VALUE;
}
@Override
public void addMessage(final UserMessage.NewUserMessage newUserMessage) {
final var getServerConfig = """
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserMessage.userId());
pstmt.setString(2, newUserMessage.question());
pstmt.setString(3, newUserMessage.answer());
pstmt.setInt(4, newUserMessage.tokens());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for user with id: " + newUserMessage.userId());
}
} catch (final SQLException e) {
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
}
}
@Override
public List<UserMessage> getLastMessages(final String userId, final int limit) {
final var getLastMessages = """
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getLastMessages)
) {
pstmt.setString(1, userId);
pstmt.setInt(2, limit);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
logger.error("Fetching last messages for user whit id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
}
return List.of();
}
@Override
public long tokensOfLast30Days(final String userId) {
final var countTokensOfLast30Days = """
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, userId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getLong(1);
} catch (final SQLException e) {
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
}
logger.error("No tokens found for user with id: " + userId);
return 0;
}
}

View file

@ -4,11 +4,11 @@ import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Iterator; import java.util.Iterator;
public final class ResultSetIterator<T> implements Iterator<T> { public class ResultSetIterator<T> implements Iterator<T> {
private final ResultSet resultSet; private final ResultSet resultSet;
private final ResultSetTransformer<T> transformer; private final ResultSetTransformer<T> transformer;
public ResultSetIterator(final ResultSet resultSet, final ResultSetTransformer<T> transformer) { public ResultSetIterator(ResultSet resultSet, ResultSetTransformer<T> transformer) {
this.resultSet = resultSet; this.resultSet = resultSet;
this.transformer = transformer; this.transformer = transformer;
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.db; package de.hhhammer.dchat.db;
public final class ResultSetIteratorException extends RuntimeException { public class ResultSetIteratorException extends RuntimeException {
public ResultSetIteratorException(final Throwable cause) { public ResultSetIteratorException(Throwable cause) {
super(cause); super(cause);
} }
} }

View file

@ -2,26 +2,196 @@ package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.server.ServerConfig; import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage; import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.StreamSupport;
public interface ServerDBService { public class ServerDBService {
Optional<ServerConfig> getConfig(String serverId); private static final Logger logger = LoggerFactory.getLogger(ServerDBService.class);
private final DataSource dataSource;
List<ServerConfig> getAllConfigs() throws DBException; public ServerDBService(DataSource dataSource) {
this.dataSource = dataSource;
Optional<ServerConfig> getConfigBy(long id) throws DBException; }
void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException; public Optional<ServerConfig> getConfig(String serverId) {
var getServerConfig = """
void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException; SELECT * FROM server_configs WHERE server_id = ?
""";
void deleteConfig(long id) throws DBException; try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
int countMessagesInLastMinute(String serverId); ) {
pstmt.setString(1, serverId);
void addMessage(ServerMessage.NewServerMessage serverMessage); ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
long tokensOfLast30Days(String serverId); return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
logger.error("Getting configuration for server with id: " + serverId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
return Optional.empty();
}
return Optional.empty();
}
public List<ServerConfig> getAllConfigs() throws DBException {
var getAllowedServerSql = """
SELECT * FROM server_configs
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
) {
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
throw new DBException("Loading all configs", e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over configs", e);
}
}
public Optional<ServerConfig> getConfigBy(long id) throws DBException {
var getServerConfig = """
SELECT * FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration with id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
}
}
public void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException {
var getServerConfig = """
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.serverId());
pstmt.setString(2, newServerConfig.systemMessage());
pstmt.setInt(3, newServerConfig.rateLimit());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for server with id: " + newServerConfig.serverId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
public void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException {
var getServerConfig = """
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.systemMessage());
pstmt.setInt(2, newServerConfig.rateLimit());
pstmt.setString(3, newServerConfig.serverId());
pstmt.setLong(4, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update for server with id: " + newServerConfig.serverId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
public void deleteConfig(long id) throws DBException {
var getServerConfig = """
DELETE FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for server with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Deleting configuration for server with id: " + id, e);
}
}
public int countMessagesInLastMinute(String serverId) {
var getServerConfig = """
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getInt(1);
} catch (SQLException e) {
logger.error("Getting messages for server with id: " + serverId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
return Integer.MAX_VALUE;
}
return Integer.MAX_VALUE;
}
public void addMessage(ServerMessage.NewServerMessage serverMessage) {
var getServerConfig = """
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverMessage.serverId());
pstmt.setLong(2, serverMessage.userId());
pstmt.setInt(3, serverMessage.tokens());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for server with id: " + serverMessage.serverId());
}
} catch (SQLException e) {
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
}
}
public long tokensOfLast30Days(String serverId) {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, serverId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getLong(1);
} catch (SQLException e) {
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
}
logger.error("No tokens found for server with id: " + serverId);
return 0;
}
} }

View file

@ -2,28 +2,219 @@ package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.user.UserConfig; import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage; import de.hhhammer.dchat.db.models.user.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.StreamSupport;
public interface UserDBService { public class UserDBService {
Optional<UserConfig> getConfig(String userId); private static final Logger logger = LoggerFactory.getLogger(UserDBService.class);
private final DataSource dataSource;
Optional<UserConfig> getConfigBy(long id) throws DBException; public UserDBService(DataSource dataSource) {
this.dataSource = dataSource;
List<UserConfig> getAllConfigs() throws DBException; }
void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException; public Optional<UserConfig> getConfig(String userId) {
var getServerConfig = """
void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException; SELECT * FROM user_configs WHERE user_id = ?
""";
void deleteConfig(long id) throws DBException; try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
int countMessagesInLastMinute(String userId); ) {
pstmt.setString(1, userId);
void addMessage(UserMessage.NewUserMessage newUserMessage); ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
List<UserMessage> getLastMessages(String userId, int limit); return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
long tokensOfLast30Days(String userId); logger.error("Getting configuration for user with id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
return Optional.empty();
}
return Optional.empty();
}
public Optional<UserConfig> getConfigBy(long id) throws DBException {
var getServerConfig = """
SELECT * FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
}
}
public List<UserConfig> getAllConfigs() throws DBException {
var getServerConfig = """
SELECT * FROM user_configs
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
throw new DBException("Getting all configurations", e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over all UserConfig ResultSet", e);
}
}
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
var getServerConfig = """
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.userId());
pstmt.setString(2, newUserConfig.systemMessage());
pstmt.setInt(3, newUserConfig.contextLength());
pstmt.setInt(4, newUserConfig.rateLimit());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for user with id: " + newUserConfig.userId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
}
}
public void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException {
var getServerConfig = """
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.systemMessage());
pstmt.setInt(2, newUserConfig.rateLimit());
pstmt.setLong(3, newUserConfig.contextLength());
pstmt.setString(4, newUserConfig.userId());
pstmt.setLong(5, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Updating configuration with id: " + id, e);
}
}
public void deleteConfig(long id) throws DBException {
var getServerConfig = """
DELETE FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for user with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Deleting configuration with id: " + id, e);
}
}
public int countMessagesInLastMinute(String userId) {
var getServerConfig = """
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getInt(1);
} catch (SQLException e) {
logger.error("Getting messages for user with id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
return Integer.MAX_VALUE;
}
return Integer.MAX_VALUE;
}
public void addMessage(UserMessage.NewUserMessage newUserMessage) {
var getServerConfig = """
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserMessage.userId());
pstmt.setString(2, newUserMessage.question());
pstmt.setString(3, newUserMessage.answer());
pstmt.setInt(4, newUserMessage.tokens());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for user with id: " + newUserMessage.userId());
}
} catch (SQLException e) {
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
}
}
public List<UserMessage> getLastMessages(String userId, int limit) {
var getLastMessages = """
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getLastMessages)
) {
pstmt.setString(1, userId);
pstmt.setInt(2, limit);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
logger.error("Fetching last messages for user whit id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
}
return List.of();
}
public long tokensOfLast30Days(String userId) {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, userId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getLong(1);
} catch (SQLException e) {
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
}
logger.error("No tokens found for user with id: " + userId);
return 0;
}
} }

View file

@ -6,20 +6,20 @@ import org.slf4j.LoggerFactory;
/** /**
* Hello world! * Hello world!
*/ */
public final class App { public class App {
private static final Logger logger = LoggerFactory.getLogger(App.class); private static final Logger logger = LoggerFactory.getLogger(App.class);
private static final String DB_MIGRATION_PATH = "db/schema.sql"; private static final String DB_MIGRATION_PATH = "db/schema.sql";
public static void main(final String[] args) { public static void main(String[] args) {
final String postgresUser = System.getenv("POSTGRES_USER"); String postgresUser = System.getenv("POSTGRES_USER");
final String postgresPassword = System.getenv("POSTGRES_PASSWORD"); String postgresPassword = System.getenv("POSTGRES_PASSWORD");
final String postgresUrl = System.getenv("POSTGRES_URL"); String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) { if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL"); logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1); System.exit(1);
} }
final var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword); var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
final var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH); var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
dbMigrator.run(); dbMigrator.run();
} }
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.migration; package de.hhhammer.dchat.migration;
public final class DBMigrationException extends Exception { public class DBMigrationException extends Exception {
public DBMigrationException(final Throwable cause) { public DBMigrationException(Throwable cause) {
super(cause); super(cause);
} }
} }

View file

@ -6,12 +6,12 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
public final class DBMigrator implements Runnable { public class DBMigrator implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class); private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class);
private final MigrationExecutor migrationExecutor; private final MigrationExecutor migrationExecutor;
private final String resourcePath; private final String resourcePath;
public DBMigrator(final MigrationExecutor migrationExecutor, final String resourcePath) { public DBMigrator(MigrationExecutor migrationExecutor, String resourcePath) {
this.migrationExecutor = migrationExecutor; this.migrationExecutor = migrationExecutor;
this.resourcePath = resourcePath; this.resourcePath = resourcePath;
} }
@ -19,8 +19,8 @@ public final class DBMigrator implements Runnable {
@Override @Override
public void run() { public void run() {
logger.info("Starting db migration"); logger.info("Starting db migration");
final ClassLoader classLoader = getClass().getClassLoader(); ClassLoader classLoader = getClass().getClassLoader();
try (final InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) { try (InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
if (inputStream == null) { if (inputStream == null) {
logger.error("Migration file not found: " + resourcePath); logger.error("Migration file not found: " + resourcePath);
throw new RuntimeException("Migration file not found"); throw new RuntimeException("Migration file not found");

View file

@ -11,24 +11,24 @@ import java.sql.DriverManager;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Statement; import java.sql.Statement;
public final class MigrationExecutor { public class MigrationExecutor {
private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class); private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class);
private final String jdbcConnectionString; private final String jdbcConnectionString;
private final String username; private final String username;
private final String password; private final String password;
public MigrationExecutor(final String jdbcConnectionString, final String username, final String password) { public MigrationExecutor(String jdbcConnectionString, String username, String password) {
this.jdbcConnectionString = jdbcConnectionString; this.jdbcConnectionString = jdbcConnectionString;
this.username = username; this.username = username;
this.password = password; this.password = password;
} }
public void migrate(final InputStream input) throws DBMigrationException { public void migrate(InputStream input) throws DBMigrationException {
try (final Connection con = DriverManager try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password); .getConnection(this.jdbcConnectionString, this.username, this.password);
final Statement stmp = con.createStatement(); Statement stmp = con.createStatement();
) { ) {
final String content = new String(input.readAllBytes(), StandardCharsets.UTF_8); String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
stmp.execute(content); stmp.execute(content);
} catch (SQLException | IOException e) { } catch (SQLException | IOException e) {
throw new DBMigrationException(e); throw new DBMigrationException(e);

View file

@ -12,6 +12,10 @@
<version>1.0-SNAPSHOT</version> <version>1.0-SNAPSHOT</version>
<name>web</name> <name>web</name>
<properties>
<node.version>v18.16.0</node.version>
</properties>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>de.hhhammer.dchat</groupId> <groupId>de.hhhammer.dchat</groupId>

View file

@ -2,41 +2,41 @@ package de.hhhammer.dchat.web;
import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.db.PostgresServerDBService; import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.PostgresUserDBService; import de.hhhammer.dchat.db.UserDBService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
/** /**
* Hello world! * Hello world!
*/ */
public final class App { public class App {
private static final Logger logger = LoggerFactory.getLogger(App.class); private static final Logger logger = LoggerFactory.getLogger(App.class);
public static void main(final String[] args) { public static void main(String[] args) {
final String postgresUser = System.getenv("POSTGRES_USER"); String postgresUser = System.getenv("POSTGRES_USER");
final String postgresPassword = System.getenv("POSTGRES_PASSWORD"); String postgresPassword = System.getenv("POSTGRES_PASSWORD");
final String postgresUrl = System.getenv("POSTGRES_URL"); String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) { if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL"); logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1); System.exit(1);
} }
final String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080"; String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
final int apiPort = Integer.parseInt(apiPortStr); int apiPort = Integer.parseInt(apiPortStr);
final boolean debug = "true".equals(System.getenv("API_DEBUG")); boolean debug = "true".equals(System.getenv("API_DEBUG"));
final var config = new HikariConfig(); var config = new HikariConfig();
config.setJdbcUrl(postgresUrl); config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser); config.setUsername(postgresUser);
config.setPassword(postgresPassword); config.setPassword(postgresPassword);
try (final var ds = new HikariDataSource(config)) { try (var ds = new HikariDataSource(config)) {
final var serverDBService = new PostgresServerDBService(ds); var serverDBService = new ServerDBService(ds);
final var userDBService = new PostgresUserDBService(ds); var userDBService = new UserDBService(ds);
final var appConfig = new AppConfig(apiPort, debug); var appConfig = new AppConfig(apiPort, debug);
final var webApi = new WebAPI(serverDBService, userDBService, appConfig); var webApi = new WebAPI(serverDBService, userDBService, appConfig);
webApi.run(); webApi.run();
} }
} }

View file

@ -6,6 +6,7 @@ import de.hhhammer.dchat.web.server.ConfigCrudHandler;
import de.hhhammer.dchat.web.user.ConfigUserCrudHandler; import de.hhhammer.dchat.web.user.ConfigUserCrudHandler;
import io.javalin.Javalin; import io.javalin.Javalin;
import io.javalin.http.HttpStatus; import io.javalin.http.HttpStatus;
import io.javalin.http.staticfiles.Location;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -15,13 +16,13 @@ import java.util.concurrent.ExecutionException;
import static io.javalin.apibuilder.ApiBuilder.crud; import static io.javalin.apibuilder.ApiBuilder.crud;
import static io.javalin.apibuilder.ApiBuilder.path; import static io.javalin.apibuilder.ApiBuilder.path;
public final class WebAPI implements Runnable { public class WebAPI implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(WebAPI.class); private static final Logger logger = LoggerFactory.getLogger(WebAPI.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
private final UserDBService userDBService; private final UserDBService userDBService;
private final AppConfig appConfig; private final AppConfig appConfig;
public WebAPI(final ServerDBService serverDBService, final UserDBService userDBService, final AppConfig appConfig) { public WebAPI(ServerDBService serverDBService, UserDBService userDBService, AppConfig appConfig) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
this.userDBService = userDBService; this.userDBService = userDBService;
this.appConfig = appConfig; this.appConfig = appConfig;
@ -30,12 +31,12 @@ public final class WebAPI implements Runnable {
@Override @Override
public void run() { public void run() {
logger.info("Starting web application"); logger.info("Starting web application");
final Javalin app = Javalin.create(config -> { var app = Javalin.create(config -> {
if (appConfig.debug()) config.plugins.enableDevLogging(); if (appConfig.debug()) config.plugins.enableDevLogging();
config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method
config.http.defaultContentType = "application/json"; config.http.defaultContentType = "application/json";
}); });
final var waitForShutdown = new CompletableFuture<Void>(); var waitForShutdown = new CompletableFuture<Void>();
Runtime.getRuntime().addShutdownHook(new Thread(() -> { Runtime.getRuntime().addShutdownHook(new Thread(() -> {
logger.info("Shutting down web application"); logger.info("Shutting down web application");
app.stop(); app.stop();

View file

@ -10,21 +10,18 @@ import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.util.List; public class ConfigCrudHandler implements CrudHandler {
import java.util.Optional;
public final class ConfigCrudHandler implements CrudHandler {
private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class); private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
public ConfigCrudHandler(final ServerDBService serverDBService) { public ConfigCrudHandler(ServerDBService serverDBService) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
} }
@Override @Override
public void create(@NotNull final Context context) { public void create(@NotNull Context context) {
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class); var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
try { try {
this.serverDBService.addConfig(body); this.serverDBService.addConfig(body);
} catch (DBException e) { } catch (DBException e) {
@ -35,10 +32,9 @@ public final class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void delete(@NotNull final Context context, @NotNull final String s) { public void delete(@NotNull Context context, @NotNull String s) {
final var id = Long.parseLong(s);
try { try {
this.serverDBService.deleteConfig(id); this.serverDBService.deleteConfig(Long.parseLong(s));
context.status(HttpStatus.NO_CONTENT); context.status(HttpStatus.NO_CONTENT);
} catch (DBException e) { } catch (DBException e) {
logger.error("Deleting configuration with id: " + s, e); logger.error("Deleting configuration with id: " + s, e);
@ -47,9 +43,9 @@ public final class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void getAll(@NotNull final Context context) { public void getAll(@NotNull Context context) {
try { try {
final List<ServerConfig> allowedServers = this.serverDBService.getAllConfigs(); var allowedServers = this.serverDBService.getAllConfigs();
context.json(allowedServers); context.json(allowedServers);
} catch (DBException e) { } catch (DBException e) {
logger.error("Getting all server configs", e); logger.error("Getting all server configs", e);
@ -58,10 +54,10 @@ public final class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void getOne(@NotNull final Context context, @NotNull final String s) { public void getOne(@NotNull Context context, @NotNull String s) {
final var id = Long.parseLong(s); var id = Long.parseLong(s);
try { try {
final Optional<ServerConfig> server = this.serverDBService.getConfigBy(id); var server = this.serverDBService.getConfigBy(id);
if (server.isEmpty()) { if (server.isEmpty()) {
context.status(HttpStatus.NOT_FOUND); context.status(HttpStatus.NOT_FOUND);
return; return;
@ -74,9 +70,9 @@ public final class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void update(@NotNull final Context context, @NotNull final String idString) { public void update(@NotNull Context context, @NotNull String idString) {
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class); var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
final var id = Long.parseLong(idString); var id = Long.parseLong(idString);
try { try {
this.serverDBService.updateConfig(id, body); this.serverDBService.updateConfig(id, body);

View file

@ -10,21 +10,18 @@ import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.util.List; public class ConfigUserCrudHandler implements CrudHandler {
import java.util.Optional;
public final class ConfigUserCrudHandler implements CrudHandler {
private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class); private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class);
private final UserDBService userDBService; private final UserDBService userDBService;
public ConfigUserCrudHandler(final UserDBService userDBService) { public ConfigUserCrudHandler(UserDBService userDBService) {
this.userDBService = userDBService; this.userDBService = userDBService;
} }
@Override @Override
public void create(@NotNull final Context context) { public void create(@NotNull Context context) {
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class); var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
try { try {
this.userDBService.addConfig(body); this.userDBService.addConfig(body);
} catch (DBException e) { } catch (DBException e) {
@ -36,10 +33,9 @@ public final class ConfigUserCrudHandler implements CrudHandler {
} }
@Override @Override
public void delete(@NotNull final Context context, @NotNull final String s) { public void delete(@NotNull Context context, @NotNull String s) {
final var id = Long.parseLong(s);
try { try {
this.userDBService.deleteConfig(id); this.userDBService.deleteConfig(Long.parseLong(s));
context.status(HttpStatus.NO_CONTENT); context.status(HttpStatus.NO_CONTENT);
} catch (DBException e) { } catch (DBException e) {
logger.error("Deleting configuration with id: " + s, e); logger.error("Deleting configuration with id: " + s, e);
@ -48,9 +44,9 @@ public final class ConfigUserCrudHandler implements CrudHandler {
} }
@Override @Override
public void getAll(@NotNull final Context context) { public void getAll(@NotNull Context context) {
try { try {
final List<UserConfig> allowedServers = this.userDBService.getAllConfigs(); var allowedServers = this.userDBService.getAllConfigs();
context.json(allowedServers); context.json(allowedServers);
} catch (DBException e) { } catch (DBException e) {
logger.error("Getting all user configs", e); logger.error("Getting all user configs", e);
@ -59,10 +55,10 @@ public final class ConfigUserCrudHandler implements CrudHandler {
} }
@Override @Override
public void getOne(@NotNull final Context context, @NotNull final String s) { public void getOne(@NotNull Context context, @NotNull String s) {
final var id = Long.parseLong(s); var id = Long.parseLong(s);
try { try {
final Optional<UserConfig> server = this.userDBService.getConfigBy(id); var server = this.userDBService.getConfigBy(id);
if (server.isEmpty()) { if (server.isEmpty()) {
context.status(HttpStatus.NOT_FOUND); context.status(HttpStatus.NOT_FOUND);
return; return;
@ -76,8 +72,8 @@ public final class ConfigUserCrudHandler implements CrudHandler {
@Override @Override
public void update(@NotNull Context context, @NotNull String idString) { public void update(@NotNull Context context, @NotNull String idString) {
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class); var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
final var id = Long.parseLong(idString); var id = Long.parseLong(idString);
try { try {
this.userDBService.updateConfig(id, body); this.userDBService.updateConfig(id, body);