Compare commits

...

5 commits

Author SHA1 Message Date
168a5d6c81 web: Remove old variable for the nodejs plugin
All checks were successful
ci/woodpecker/push/java/1 Pipeline was successful
ci/woodpecker/push/java/2 Pipeline was successful
ci/woodpecker/push/java/3 Pipeline was successful
ci/woodpecker/push/java/4 Pipeline was successful
ci/woodpecker/push/nodejs Pipeline was successful
ci/woodpecker/push/oci-image-build/1 Pipeline was successful
ci/woodpecker/push/oci-image-build/2 Pipeline was successful
ci/woodpecker/push/oci-image-build/3 Pipeline was successful
ci/woodpecker/push/oci-image-build/4 Pipeline was successful
ci/woodpecker/tag/java/1 Pipeline was successful
ci/woodpecker/tag/java/2 Pipeline was successful
ci/woodpecker/tag/java/3 Pipeline was successful
ci/woodpecker/tag/java/4 Pipeline was successful
ci/woodpecker/tag/nodejs Pipeline was successful
ci/woodpecker/tag/oci-image-build/1 Pipeline was successful
ci/woodpecker/tag/oci-image-build/2 Pipeline was successful
ci/woodpecker/tag/oci-image-build/3 Pipeline was successful
ci/woodpecker/tag/oci-image-build/4 Pipeline was successful
2024-01-25 20:35:07 +01:00
b77836effb misc: Make all variables final 2024-01-25 20:32:30 +01:00
e58980cea3 misc: Make classes final 2024-01-25 19:04:04 +01:00
ca95bb45cb misc: Reformat and optimise imports 2024-01-25 18:58:31 +01:00
ab48afd5ed db: Introduce interface to enable testability 2024-01-25 18:56:51 +01:00
24 changed files with 642 additions and 562 deletions

View file

@ -4,47 +4,47 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.db.ServerDBService; import de.hhhammer.dchat.db.PostgresServerDBService;
import de.hhhammer.dchat.db.UserDBService; import de.hhhammer.dchat.db.PostgresUserDBService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.net.http.HttpClient; import java.net.http.HttpClient;
public class App { public final class App {
private static final Logger logger = LoggerFactory.getLogger(App.class); private static final Logger logger = LoggerFactory.getLogger(App.class);
public static void main(String[] args) { public static void main(final String[] args) {
String discordApiKey = System.getenv("DISCORD_API_KEY"); final String discordApiKey = System.getenv("DISCORD_API_KEY");
if (discordApiKey == null) { if (discordApiKey == null) {
logger.error("Missing environment variables: DISCORD_API_KEY"); logger.error("Missing environment variables: DISCORD_API_KEY");
System.exit(1); System.exit(1);
} }
String openaiApiKey = System.getenv("OPENAI_API_KEY"); final String openaiApiKey = System.getenv("OPENAI_API_KEY");
if (openaiApiKey == null) { if (openaiApiKey == null) {
logger.error("Missing environment variables: OPENAI_API_KEY"); logger.error("Missing environment variables: OPENAI_API_KEY");
System.exit(1); System.exit(1);
} }
String postgresUser = System.getenv("POSTGRES_USER"); final String postgresUser = System.getenv("POSTGRES_USER");
String postgresPassword = System.getenv("POSTGRES_PASSWORD"); final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
String postgresUrl = System.getenv("POSTGRES_URL"); final String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) { if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL"); logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1); System.exit(1);
} }
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper()); final var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
var config = new HikariConfig(); final var config = new HikariConfig();
config.setJdbcUrl(postgresUrl); config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser); config.setUsername(postgresUser);
config.setPassword(postgresPassword); config.setPassword(postgresPassword);
try (var ds = new HikariDataSource(config)) { try (var ds = new HikariDataSource(config)) {
var serverDBService = new ServerDBService(ds); final var serverDBService = new PostgresServerDBService(ds);
var userDBService = new UserDBService(ds); final var userDBService = new PostgresUserDBService(ds);
var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey); final var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
discordBot.run(); discordBot.run();
} }
} }

View file

@ -9,6 +9,7 @@ import de.hhhammer.dchat.db.UserDBService;
import org.javacord.api.DiscordApi; import org.javacord.api.DiscordApi;
import org.javacord.api.DiscordApiBuilder; import org.javacord.api.DiscordApiBuilder;
import org.javacord.api.interaction.SlashCommand; import org.javacord.api.interaction.SlashCommand;
import org.javacord.api.interaction.SlashCommandInteraction;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -16,7 +17,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public class DiscordBot implements Runnable { public final class DiscordBot implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class); private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
@ -24,7 +25,7 @@ public class DiscordBot implements Runnable {
private final ChatGPTService chatGPTService; private final ChatGPTService chatGPTService;
private final String discordApiKey; private final String discordApiKey;
public DiscordBot(ServerDBService serverDBService, UserDBService userDBService, ChatGPTService chatGPTService, String discordApiKey) { public DiscordBot(final ServerDBService serverDBService, final UserDBService userDBService, final ChatGPTService chatGPTService, final String discordApiKey) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
this.userDBService = userDBService; this.userDBService = userDBService;
this.chatGPTService = chatGPTService; this.chatGPTService = chatGPTService;
@ -34,29 +35,29 @@ public class DiscordBot implements Runnable {
@Override @Override
public void run() { public void run() {
logger.info("Starting Discord application"); logger.info("Starting Discord application");
DiscordApi discordApi = new DiscordApiBuilder() final DiscordApi discordApi = new DiscordApiBuilder()
.setToken(discordApiKey) .setToken(discordApiKey)
.login() .login()
.join(); .join();
discordApi.setMessageCacheSize(10, 60*60); discordApi.setMessageCacheSize(10, 60 * 60);
var future = new CompletableFuture<Void>(); final var future = new CompletableFuture<Void>();
Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> { Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> {
logger.info("Shutting down Discord application"); logger.info("Shutting down Discord application");
discordApi.disconnect().thenAccept(future::complete); discordApi.disconnect().thenAccept(future::complete);
})); }));
var token = SlashCommand.with("tokens", "Check how many tokens where spend on this server") final SlashCommand token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
.createGlobal(discordApi) .createGlobal(discordApi)
.join(); .join();
discordApi.addSlashCommandCreateListener(event -> { discordApi.addSlashCommandCreateListener(event -> {
logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName()); logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName());
var command = event.getSlashCommandInteraction(); final SlashCommandInteraction command = event.getSlashCommandInteraction();
if (token.getFullCommandNames().contains(command.getFullCommandName())) { if (token.getFullCommandNames().contains(command.getFullCommandName())) {
event.getInteraction() event.getInteraction()
.respondLater() .respondLater()
.orTimeout(30, TimeUnit.SECONDS) .orTimeout(30, TimeUnit.SECONDS)
.thenAccept((interactionOriginalResponseUpdater) -> { .thenAccept((interactionOriginalResponseUpdater) -> {
var tokens = event.getInteraction().getServer().isPresent() ? final long tokens = event.getInteraction().getServer().isPresent() ?
this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) : this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) :
this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId())); this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId()));
interactionOriginalResponseUpdater.setContent("" + tokens).update(); interactionOriginalResponseUpdater.setContent("" + tokens).update();

View file

@ -9,17 +9,17 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
public class MessageCreateHandler implements MessageCreateListener { public final class MessageCreateHandler implements MessageCreateListener {
private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class); private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class);
private final MessageHandler messageHandler; private final MessageHandler messageHandler;
public MessageCreateHandler(MessageHandler messageHandler) { public MessageCreateHandler(final MessageHandler messageHandler) {
this.messageHandler = messageHandler; this.messageHandler = messageHandler;
} }
@Override @Override
public void onMessageCreate(MessageCreateEvent event) { public void onMessageCreate(final MessageCreateEvent event) {
Thread.ofVirtual().start(() -> { Thread.ofVirtual().start(() -> {
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) { if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
return; return;
@ -36,15 +36,15 @@ public class MessageCreateHandler implements MessageCreateListener {
} }
try { try {
this.messageHandler.handle(event); this.messageHandler.handle(event);
} catch (ResponseException | IOException | InterruptedException e) { } catch (final ResponseException | IOException | InterruptedException e) {
logger.error("Reading a message from the listener", e); logger.error("Reading a message from the listener", e);
event.getMessage().reply("Sorry but something went wrong :("); event.getMessage().reply("Sorry but something went wrong :(");
} }
}); });
} }
private boolean isNormalOrReplyMessageType(MessageCreateEvent event) { private boolean isNormalOrReplyMessageType(final MessageCreateEvent event) {
MessageType type = event.getMessage().getType(); final MessageType type = event.getMessage().getType();
return type == MessageType.NORMAL || type == MessageType.REPLY; return type == MessageType.NORMAL || type == MessageType.REPLY;
} }
} }

View file

@ -4,7 +4,10 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction; import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException; import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import de.hhhammer.dchat.db.ServerDBService; import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage; import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.javacord.api.entity.DiscordEntity; import org.javacord.api.entity.DiscordEntity;
import org.javacord.api.entity.message.Message; import org.javacord.api.entity.message.Message;
@ -17,43 +20,44 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional;
public class ServerMessageHandler implements MessageHandler { public final class ServerMessageHandler implements MessageHandler {
private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class); private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
private final ChatGPTService chatGPTService; private final ChatGPTService chatGPTService;
public ServerMessageHandler(ServerDBService serverDBService, ChatGPTService chatGPTService) { public ServerMessageHandler(final ServerDBService serverDBService, final ChatGPTService chatGPTService) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
this.chatGPTService = chatGPTService; this.chatGPTService = chatGPTService;
} }
@Override @Override
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException { public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
String content = extractContent(event); final String content = extractContent(event);
var serverId = event.getServer().get().getId(); final long serverId = event.getServer().get().getId();
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage(); final String systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of(); final List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
var request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage); final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
var response = this.chatGPTService.submit(request); final ChatGPTResponse response = this.chatGPTService.submit(request);
if (response.choices().isEmpty()) { if (response.choices().isEmpty()) {
event.getMessage().reply("No response available"); event.getMessage().reply("No response available");
return; return;
} }
var answer = response.choices().get(0).message().content(); final String answer = response.choices().get(0).message().content();
logServerMessage(event, response.usage().totalTokens()); logServerMessage(event, response.usage().totalTokens());
event.getMessage().reply(answer); event.getMessage().reply(answer);
} }
@Override @Override
public boolean isAllowed(MessageCreateEvent event) { public boolean isAllowed(final MessageCreateEvent event) {
if (event.getServer().isEmpty()) { if (event.getServer().isEmpty()) {
return false; return false;
} }
var serverId = event.getServer().get().getId(); final long serverId = event.getServer().get().getId();
var config = this.serverDBService.getConfig(String.valueOf(serverId)); final Optional<ServerConfig> config = this.serverDBService.getConfig(String.valueOf(serverId));
if (config.isEmpty()) { if (config.isEmpty()) {
logger.debug("Not allowed with id: " + serverId); logger.debug("Not allowed with id: " + serverId);
return false; return false;
@ -62,39 +66,39 @@ public class ServerMessageHandler implements MessageHandler {
} }
@Override @Override
public boolean exceedsRate(MessageCreateEvent event) { public boolean exceedsRate(final MessageCreateEvent event) {
var serverId = String.valueOf(event.getServer().get().getId()); final String serverId = String.valueOf(event.getServer().get().getId());
var config = this.serverDBService.getConfig(serverId); final Optional<ServerConfig> config = this.serverDBService.getConfig(serverId);
if (config.isEmpty()) { if (config.isEmpty()) {
logger.error("Missing configuration for server with id: " + serverId); logger.error("Missing configuration for server with id: " + serverId);
return true; return true;
} }
var rateLimit = config.get().rateLimit(); final int rateLimit = config.get().rateLimit();
var countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId); final int countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
return countMessagesInLastMinute >= rateLimit; return countMessagesInLastMinute >= rateLimit;
} }
@Override @Override
public boolean canHandle(MessageCreateEvent event) { public boolean canHandle(final MessageCreateEvent event) {
return event.isServerMessage(); return event.isServerMessage();
} }
private void logServerMessage(MessageCreateEvent event, int tokens) { private void logServerMessage(final MessageCreateEvent event, final int tokens) {
var serverId = event.getServer().map(DiscordEntity::getId).get(); final long serverId = event.getServer().map(DiscordEntity::getId).get();
var userId = event.getMessageAuthor().getId(); final long userId = event.getMessageAuthor().getId();
var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens); final var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
this.serverDBService.addMessage(serverMessage); this.serverDBService.addMessage(serverMessage);
} }
private String extractContent(MessageCreateEvent event) { private String extractContent(final MessageCreateEvent event) {
long ownId = event.getApi().getYourself().getId(); final long ownId = event.getApi().getYourself().getId();
return event.getMessageContent().replaceFirst("<" + ownId + "> ", ""); return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
} }
@NotNull @NotNull
private List<ReplyInteraction> getContextMessages(MessageCreateEvent event) { private List<ReplyInteraction> getContextMessages(final MessageCreateEvent event) {
return event.getMessage() return event.getMessage()
.getMessageReference() .getMessageReference()
.map(MessageReference::getMessage) .map(MessageReference::getMessage)

View file

@ -4,7 +4,10 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService; import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction; import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException; import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import de.hhhammer.dchat.db.UserDBService; import de.hhhammer.dchat.db.UserDBService;
import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage; import de.hhhammer.dchat.db.models.user.UserMessage;
import org.javacord.api.event.message.MessageCreateEvent; import org.javacord.api.event.message.MessageCreateEvent;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -12,46 +15,47 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional;
public class UserMessageHandler implements MessageHandler { public final class UserMessageHandler implements MessageHandler {
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class); private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
private final UserDBService userDBService; private final UserDBService userDBService;
private final ChatGPTService chatGPTService; private final ChatGPTService chatGPTService;
public UserMessageHandler(UserDBService userDBService, ChatGPTService chatGPTService) { public UserMessageHandler(final UserDBService userDBService, final ChatGPTService chatGPTService) {
this.userDBService = userDBService; this.userDBService = userDBService;
this.chatGPTService = chatGPTService; this.chatGPTService = chatGPTService;
} }
@Override @Override
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException { public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
String content = event.getReadableMessageContent(); final String content = event.getReadableMessageContent();
var userId = String.valueOf(event.getMessageAuthor().getId()); final String userId = String.valueOf(event.getMessageAuthor().getId());
var config = this.userDBService.getConfig(userId).get(); final UserConfig config = this.userDBService.getConfig(userId).get();
var systemMessage = config.systemMessage(); final String systemMessage = config.systemMessage();
List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength()) final List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
.stream() .stream()
.map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer())) .map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
.toList(); .toList();
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage); final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
var response = this.chatGPTService.submit(request); final ChatGPTResponse response = this.chatGPTService.submit(request);
if (response.choices().isEmpty()) { if (response.choices().isEmpty()) {
event.getMessage().reply("No response available"); event.getMessage().reply("No response available");
return; return;
} }
var answer = response.choices().get(0).message().content(); final String answer = response.choices().get(0).message().content();
logUserMessage(event, content, answer, response.usage().totalTokens()); logUserMessage(event, content, answer, response.usage().totalTokens());
event.getChannel().sendMessage(answer); event.getChannel().sendMessage(answer);
} }
@Override @Override
public boolean isAllowed(MessageCreateEvent event) { public boolean isAllowed(final MessageCreateEvent event) {
if (event.getServer().isPresent()) { if (event.getServer().isPresent()) {
return false; return false;
} }
var userId = event.getMessageAuthor().getId(); final long userId = event.getMessageAuthor().getId();
var config = this.userDBService.getConfig(String.valueOf(userId)); final Optional<UserConfig> config = this.userDBService.getConfig(String.valueOf(userId));
if (config.isEmpty()) { if (config.isEmpty()) {
logger.debug("Not allowed with id: " + userId); logger.debug("Not allowed with id: " + userId);
return false; return false;
@ -60,28 +64,28 @@ public class UserMessageHandler implements MessageHandler {
} }
@Override @Override
public boolean exceedsRate(MessageCreateEvent event) { public boolean exceedsRate(final MessageCreateEvent event) {
var userId = String.valueOf(event.getMessageAuthor().getId()); final String userId = String.valueOf(event.getMessageAuthor().getId());
var config = this.userDBService.getConfig(userId); final Optional<UserConfig> config = this.userDBService.getConfig(userId);
if (config.isEmpty()) { if (config.isEmpty()) {
logger.error("Missing configuration for userId with id: " + userId); logger.error("Missing configuration for userId with id: " + userId);
return true; return true;
} }
var rateLimit = config.get().rateLimit(); final int rateLimit = config.get().rateLimit();
var countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId); final int countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
return countMessagesInLastMinute >= rateLimit; return countMessagesInLastMinute >= rateLimit;
} }
@Override @Override
public boolean canHandle(MessageCreateEvent event) { public boolean canHandle(final MessageCreateEvent event) {
return event.isPrivateMessage(); return event.isPrivateMessage();
} }
private void logUserMessage(MessageCreateEvent event, String question, String answer, int tokens) { private void logUserMessage(final MessageCreateEvent event, final String question, final String answer, final int tokens) {
var userId = event.getMessageAuthor().getId(); final long userId = event.getMessageAuthor().getId();
var userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens); final UserMessage.NewUserMessage userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
this.userDBService.addMessage(userMessage); this.userDBService.addMessage(userMessage);
} }
} }

View file

@ -8,29 +8,19 @@ import org.jetbrains.annotations.NotNull;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
public class ChatGPTRequestBuilder { public final class ChatGPTRequestBuilder {
private static final String smallContextModel = "gpt-3.5-turbo"; private static final String smallContextModel = "gpt-3.5-turbo";
private static final String bigContextModel = "gpt-3.5-turbo-16k"; private static final String bigContextModel = "gpt-3.5-turbo-16k";
public ChatGPTRequestBuilder() { public ChatGPTRequestBuilder() {
} }
public ChatGPTRequest contextRequest(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
return new ChatGPTRequest(
contextModel,
messages,
0.7f
);
}
@NotNull @NotNull
private static List<ChatGPTRequest.Message> getMessages(List<? extends MessageContext> contextMessages, String message, String systemMessage) { private static List<ChatGPTRequest.Message> getMessages(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
var systemMsg = new ChatGPTRequest.Message("system", systemMessage); final ChatGPTRequest.Message systemMsg = new ChatGPTRequest.Message("system", systemMessage);
List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages); final List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
var userMessage = new ChatGPTRequest.Message("user", message); final ChatGPTRequest.Message userMessage = new ChatGPTRequest.Message("user", message);
List<ChatGPTRequest.Message> messages = new ArrayList<>(); final List<ChatGPTRequest.Message> messages = new ArrayList<>();
messages.add(systemMsg); messages.add(systemMsg);
messages.addAll(contextMsgs); messages.addAll(contextMsgs);
messages.add(userMessage); messages.add(userMessage);
@ -38,7 +28,7 @@ public class ChatGPTRequestBuilder {
} }
@NotNull @NotNull
private static List<ChatGPTRequest.Message> getContextMessages(List<? extends MessageContext> contextMessages) { private static List<ChatGPTRequest.Message> getContextMessages(final List<? extends MessageContext> contextMessages) {
return contextMessages.stream() return contextMessages.stream()
.map(ChatGPTRequestBuilder::mapContextMessages) .map(ChatGPTRequestBuilder::mapContextMessages)
.flatMap(List::stream) .flatMap(List::stream)
@ -46,7 +36,7 @@ public class ChatGPTRequestBuilder {
} }
@NotNull @NotNull
private static List<ChatGPTRequest.Message> mapContextMessages(MessageContext contextMessage) { private static List<ChatGPTRequest.Message> mapContextMessages(final MessageContext contextMessage) {
return switch (contextMessage) { return switch (contextMessage) {
case PreviousInteraction previousInteractions -> List.of( case PreviousInteraction previousInteractions -> List.of(
new ChatGPTRequest.Message("user", previousInteractions.question()), new ChatGPTRequest.Message("user", previousInteractions.question()),
@ -56,4 +46,14 @@ public class ChatGPTRequestBuilder {
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer())); List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
}; };
} }
public ChatGPTRequest contextRequest(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
final List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
final String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
return new ChatGPTRequest(
contextModel,
messages,
0.7f
);
}
} }

View file

@ -5,34 +5,35 @@ import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse; import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.net.http.HttpClient; import java.net.http.HttpClient;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.time.Duration; import java.time.Duration;
public class ChatGPTService { public final class ChatGPTService {
private static final String url = "https://api.openai.com/v1/chat/completions"; private static final String url = "https://api.openai.com/v1/chat/completions";
private final String apiKey; private final String apiKey;
private final HttpClient httpClient; private final HttpClient httpClient;
private final ObjectMapper mapper; private final ObjectMapper mapper;
public ChatGPTService(String apiKey, HttpClient httpClient, ObjectMapper mapper) { public ChatGPTService(final String apiKey, final HttpClient httpClient, final ObjectMapper mapper) {
this.apiKey = apiKey; this.apiKey = apiKey;
this.httpClient = httpClient; this.httpClient = httpClient;
this.mapper = mapper; this.mapper = mapper;
} }
public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException { public ChatGPTResponse submit(final ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
var data = mapper.writeValueAsBytes(chatGPTRequest); final byte[] data = mapper.writeValueAsBytes(chatGPTRequest);
var request = HttpRequest.newBuilder(URI.create(url)) final HttpRequest request = HttpRequest.newBuilder(URI.create(url))
.POST(HttpRequest.BodyPublishers.ofByteArray(data)) .POST(HttpRequest.BodyPublishers.ofByteArray(data))
.setHeader("Content-Type", "application/json") .setHeader("Content-Type", "application/json")
.setHeader("Authorization", "Bearer " + this.apiKey) .setHeader("Authorization", "Bearer " + this.apiKey)
.timeout(Duration.ofMinutes(5)) .timeout(Duration.ofMinutes(5))
.build(); .build();
var responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); final HttpResponse<InputStream> responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
if (responseStream.statusCode() != 200) { if (responseStream.statusCode() != 200) {
throw new ResponseException("Response status code was not 200: " + responseStream.statusCode()); throw new ResponseException("Response status code was not 200: " + responseStream.statusCode());
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.bot.openai; package de.hhhammer.dchat.bot.openai;
public class ResponseException extends Exception { public final class ResponseException extends Exception {
public ResponseException(String message) { public ResponseException(final String message) {
super(message); super(message);
} }
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.db; package de.hhhammer.dchat.db;
public class DBException extends Exception { public final class DBException extends Exception {
public DBException(String message, Throwable cause) { public DBException(final String message, final Throwable cause) {
super(message, cause); super(message, cause);
} }
} }

View file

@ -0,0 +1,202 @@
package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public final class PostgresServerDBService implements ServerDBService {
private static final Logger logger = LoggerFactory.getLogger(PostgresServerDBService.class);
private final DataSource dataSource;
public PostgresServerDBService(final DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public Optional<ServerConfig> getConfig(final String serverId) {
final var getServerConfig = """
SELECT * FROM server_configs WHERE server_id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
logger.error("Getting configuration for server with id: " + serverId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
}
return Optional.empty();
}
@Override
public List<ServerConfig> getAllConfigs() throws DBException {
final var getAllowedServerSql = """
SELECT * FROM server_configs
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
) {
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
throw new DBException("Loading all configs", e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over configs", e);
}
}
@Override
public Optional<ServerConfig> getConfigBy(final long id) throws DBException {
final var getServerConfig = """
SELECT * FROM server_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration with id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
}
}
@Override
public void addConfig(final ServerConfig.NewServerConfig newServerConfig) throws DBException {
final var getServerConfig = """
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.serverId());
pstmt.setString(2, newServerConfig.systemMessage());
pstmt.setInt(3, newServerConfig.rateLimit());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for server with id: " + newServerConfig.serverId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
@Override
public void updateConfig(final long id, final ServerConfig.NewServerConfig newServerConfig) throws DBException {
final var getServerConfig = """
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.systemMessage());
pstmt.setInt(2, newServerConfig.rateLimit());
pstmt.setString(3, newServerConfig.serverId());
pstmt.setLong(4, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update for server with id: " + newServerConfig.serverId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
@Override
public void deleteConfig(final long id) throws DBException {
final var getServerConfig = """
DELETE FROM server_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for server with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Deleting configuration for server with id: " + id, e);
}
}
@Override
public int countMessagesInLastMinute(final String serverId) {
final var getServerConfig = """
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getInt(1);
} catch (final SQLException e) {
logger.error("Getting messages for server with id: " + serverId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
}
return Integer.MAX_VALUE;
}
@Override
public void addMessage(final ServerMessage.NewServerMessage serverMessage) {
final var getServerConfig = """
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverMessage.serverId());
pstmt.setLong(2, serverMessage.userId());
pstmt.setInt(3, serverMessage.tokens());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for server with id: " + serverMessage.serverId());
}
} catch (final SQLException e) {
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
}
}
@Override
public long tokensOfLast30Days(final String serverId) {
final var countTokensOfLast30Days = """
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, serverId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getLong(1);
} catch (final SQLException e) {
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
}
logger.error("No tokens found for server with id: " + serverId);
return 0;
}
}

View file

@ -0,0 +1,226 @@
package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public final class PostgresUserDBService implements UserDBService {
private static final Logger logger = LoggerFactory.getLogger(PostgresUserDBService.class);
private final DataSource dataSource;
public PostgresUserDBService(final DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public Optional<UserConfig> getConfig(final String userId) {
final var getServerConfig = """
SELECT * FROM user_configs WHERE user_id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
logger.error("Getting configuration for user with id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
}
return Optional.empty();
}
@Override
public Optional<UserConfig> getConfigBy(final long id) throws DBException {
final var getServerConfig = """
SELECT * FROM user_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
throw new DBException("Getting configuration id: " + id, e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
}
}
@Override
public List<UserConfig> getAllConfigs() throws DBException {
final var getServerConfig = """
SELECT * FROM user_configs
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
throw new DBException("Getting all configurations", e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over all UserConfig ResultSet", e);
}
}
@Override
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
final var getServerConfig = """
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.userId());
pstmt.setString(2, newUserConfig.systemMessage());
pstmt.setInt(3, newUserConfig.contextLength());
pstmt.setInt(4, newUserConfig.rateLimit());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for user with id: " + newUserConfig.userId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
}
}
@Override
public void updateConfig(final long id, final UserConfig.NewUserConfig newUserConfig) throws DBException {
final var getServerConfig = """
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.systemMessage());
pstmt.setInt(2, newUserConfig.rateLimit());
pstmt.setLong(3, newUserConfig.contextLength());
pstmt.setString(4, newUserConfig.userId());
pstmt.setLong(5, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Updating configuration with id: " + id, e);
}
}
@Override
public void deleteConfig(final long id) throws DBException {
final var getServerConfig = """
DELETE FROM user_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for user with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Deleting configuration with id: " + id, e);
}
}
@Override
public int countMessagesInLastMinute(final String userId) {
final var getServerConfig = """
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getInt(1);
} catch (final SQLException e) {
logger.error("Getting messages for user with id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
}
return Integer.MAX_VALUE;
}
@Override
public void addMessage(final UserMessage.NewUserMessage newUserMessage) {
final var getServerConfig = """
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserMessage.userId());
pstmt.setString(2, newUserMessage.question());
pstmt.setString(3, newUserMessage.answer());
pstmt.setInt(4, newUserMessage.tokens());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for user with id: " + newUserMessage.userId());
}
} catch (final SQLException e) {
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
}
}
@Override
public List<UserMessage> getLastMessages(final String userId, final int limit) {
final var getLastMessages = """
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getLastMessages)
) {
pstmt.setString(1, userId);
pstmt.setInt(2, limit);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
logger.error("Fetching last messages for user whit id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
}
return List.of();
}
@Override
public long tokensOfLast30Days(final String userId) {
final var countTokensOfLast30Days = """
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, userId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getLong(1);
} catch (final SQLException e) {
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
}
logger.error("No tokens found for user with id: " + userId);
return 0;
}
}

View file

@ -4,11 +4,11 @@ import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Iterator; import java.util.Iterator;
public class ResultSetIterator<T> implements Iterator<T> { public final class ResultSetIterator<T> implements Iterator<T> {
private final ResultSet resultSet; private final ResultSet resultSet;
private final ResultSetTransformer<T> transformer; private final ResultSetTransformer<T> transformer;
public ResultSetIterator(ResultSet resultSet, ResultSetTransformer<T> transformer) { public ResultSetIterator(final ResultSet resultSet, final ResultSetTransformer<T> transformer) {
this.resultSet = resultSet; this.resultSet = resultSet;
this.transformer = transformer; this.transformer = transformer;
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.db; package de.hhhammer.dchat.db;
public class ResultSetIteratorException extends RuntimeException { public final class ResultSetIteratorException extends RuntimeException {
public ResultSetIteratorException(Throwable cause) { public ResultSetIteratorException(final Throwable cause) {
super(cause); super(cause);
} }
} }

View file

@ -2,196 +2,26 @@ package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.server.ServerConfig; import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage; import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.StreamSupport;
public class ServerDBService { public interface ServerDBService {
private static final Logger logger = LoggerFactory.getLogger(ServerDBService.class); Optional<ServerConfig> getConfig(String serverId);
private final DataSource dataSource;
public ServerDBService(DataSource dataSource) { List<ServerConfig> getAllConfigs() throws DBException;
this.dataSource = dataSource;
}
public Optional<ServerConfig> getConfig(String serverId) { Optional<ServerConfig> getConfigBy(long id) throws DBException;
var getServerConfig = """
SELECT * FROM server_configs WHERE server_id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
logger.error("Getting configuration for server with id: " + serverId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
return Optional.empty();
}
return Optional.empty();
}
public List<ServerConfig> getAllConfigs() throws DBException { void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException;
var getAllowedServerSql = """
SELECT * FROM server_configs
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
) {
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
throw new DBException("Loading all configs", e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over configs", e);
}
}
public Optional<ServerConfig> getConfigBy(long id) throws DBException { void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException;
var getServerConfig = """
SELECT * FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration with id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
}
}
public void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException { void deleteConfig(long id) throws DBException;
var getServerConfig = """
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.serverId());
pstmt.setString(2, newServerConfig.systemMessage());
pstmt.setInt(3, newServerConfig.rateLimit());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for server with id: " + newServerConfig.serverId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
public void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException { int countMessagesInLastMinute(String serverId);
var getServerConfig = """
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.systemMessage());
pstmt.setInt(2, newServerConfig.rateLimit());
pstmt.setString(3, newServerConfig.serverId());
pstmt.setLong(4, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update for server with id: " + newServerConfig.serverId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
public void deleteConfig(long id) throws DBException { void addMessage(ServerMessage.NewServerMessage serverMessage);
var getServerConfig = """
DELETE FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for server with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Deleting configuration for server with id: " + id, e);
}
}
public int countMessagesInLastMinute(String serverId) { long tokensOfLast30Days(String serverId);
var getServerConfig = """
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getInt(1);
} catch (SQLException e) {
logger.error("Getting messages for server with id: " + serverId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
return Integer.MAX_VALUE;
}
return Integer.MAX_VALUE;
}
public void addMessage(ServerMessage.NewServerMessage serverMessage) {
var getServerConfig = """
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverMessage.serverId());
pstmt.setLong(2, serverMessage.userId());
pstmt.setInt(3, serverMessage.tokens());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for server with id: " + serverMessage.serverId());
}
} catch (SQLException e) {
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
}
}
public long tokensOfLast30Days(String serverId) {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, serverId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getLong(1);
} catch (SQLException e) {
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
}
logger.error("No tokens found for server with id: " + serverId);
return 0;
}
} }

View file

@ -2,219 +2,28 @@ package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.user.UserConfig; import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage; import de.hhhammer.dchat.db.models.user.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.StreamSupport;
public class UserDBService { public interface UserDBService {
private static final Logger logger = LoggerFactory.getLogger(UserDBService.class); Optional<UserConfig> getConfig(String userId);
private final DataSource dataSource;
public UserDBService(DataSource dataSource) { Optional<UserConfig> getConfigBy(long id) throws DBException;
this.dataSource = dataSource;
}
public Optional<UserConfig> getConfig(String userId) { List<UserConfig> getAllConfigs() throws DBException;
var getServerConfig = """
SELECT * FROM user_configs WHERE user_id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
logger.error("Getting configuration for user with id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
return Optional.empty();
}
return Optional.empty();
}
public Optional<UserConfig> getConfigBy(long id) throws DBException { void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException;
var getServerConfig = """
SELECT * FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
}
}
public List<UserConfig> getAllConfigs() throws DBException { void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException;
var getServerConfig = """
SELECT * FROM user_configs
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
throw new DBException("Getting all configurations", e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over all UserConfig ResultSet", e);
}
}
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException { void deleteConfig(long id) throws DBException;
var getServerConfig = """
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.userId());
pstmt.setString(2, newUserConfig.systemMessage());
pstmt.setInt(3, newUserConfig.contextLength());
pstmt.setInt(4, newUserConfig.rateLimit());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for user with id: " + newUserConfig.userId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
}
}
public void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException { int countMessagesInLastMinute(String userId);
var getServerConfig = """
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.systemMessage());
pstmt.setInt(2, newUserConfig.rateLimit());
pstmt.setLong(3, newUserConfig.contextLength());
pstmt.setString(4, newUserConfig.userId());
pstmt.setLong(5, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Updating configuration with id: " + id, e);
}
}
public void deleteConfig(long id) throws DBException { void addMessage(UserMessage.NewUserMessage newUserMessage);
var getServerConfig = """
DELETE FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for user with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Deleting configuration with id: " + id, e);
}
}
public int countMessagesInLastMinute(String userId) { List<UserMessage> getLastMessages(String userId, int limit);
var getServerConfig = """
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getInt(1);
} catch (SQLException e) {
logger.error("Getting messages for user with id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
return Integer.MAX_VALUE;
}
return Integer.MAX_VALUE;
}
public void addMessage(UserMessage.NewUserMessage newUserMessage) { long tokensOfLast30Days(String userId);
var getServerConfig = """
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserMessage.userId());
pstmt.setString(2, newUserMessage.question());
pstmt.setString(3, newUserMessage.answer());
pstmt.setInt(4, newUserMessage.tokens());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for user with id: " + newUserMessage.userId());
}
} catch (SQLException e) {
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
}
}
public List<UserMessage> getLastMessages(String userId, int limit) {
var getLastMessages = """
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getLastMessages)
) {
pstmt.setString(1, userId);
pstmt.setInt(2, limit);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
logger.error("Fetching last messages for user whit id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
}
return List.of();
}
public long tokensOfLast30Days(String userId) {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, userId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getLong(1);
} catch (SQLException e) {
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
}
logger.error("No tokens found for user with id: " + userId);
return 0;
}
} }

View file

@ -6,20 +6,20 @@ import org.slf4j.LoggerFactory;
/** /**
* Hello world! * Hello world!
*/ */
public class App { public final class App {
private static final Logger logger = LoggerFactory.getLogger(App.class); private static final Logger logger = LoggerFactory.getLogger(App.class);
private static final String DB_MIGRATION_PATH = "db/schema.sql"; private static final String DB_MIGRATION_PATH = "db/schema.sql";
public static void main(String[] args) { public static void main(final String[] args) {
String postgresUser = System.getenv("POSTGRES_USER"); final String postgresUser = System.getenv("POSTGRES_USER");
String postgresPassword = System.getenv("POSTGRES_PASSWORD"); final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
String postgresUrl = System.getenv("POSTGRES_URL"); final String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) { if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL"); logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1); System.exit(1);
} }
var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword); final var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH); final var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
dbMigrator.run(); dbMigrator.run();
} }
} }

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.migration; package de.hhhammer.dchat.migration;
public class DBMigrationException extends Exception { public final class DBMigrationException extends Exception {
public DBMigrationException(Throwable cause) { public DBMigrationException(final Throwable cause) {
super(cause); super(cause);
} }
} }

View file

@ -6,12 +6,12 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
public class DBMigrator implements Runnable { public final class DBMigrator implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class); private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class);
private final MigrationExecutor migrationExecutor; private final MigrationExecutor migrationExecutor;
private final String resourcePath; private final String resourcePath;
public DBMigrator(MigrationExecutor migrationExecutor, String resourcePath) { public DBMigrator(final MigrationExecutor migrationExecutor, final String resourcePath) {
this.migrationExecutor = migrationExecutor; this.migrationExecutor = migrationExecutor;
this.resourcePath = resourcePath; this.resourcePath = resourcePath;
} }
@ -19,8 +19,8 @@ public class DBMigrator implements Runnable {
@Override @Override
public void run() { public void run() {
logger.info("Starting db migration"); logger.info("Starting db migration");
ClassLoader classLoader = getClass().getClassLoader(); final ClassLoader classLoader = getClass().getClassLoader();
try (InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) { try (final InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
if (inputStream == null) { if (inputStream == null) {
logger.error("Migration file not found: " + resourcePath); logger.error("Migration file not found: " + resourcePath);
throw new RuntimeException("Migration file not found"); throw new RuntimeException("Migration file not found");

View file

@ -11,24 +11,24 @@ import java.sql.DriverManager;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Statement; import java.sql.Statement;
public class MigrationExecutor { public final class MigrationExecutor {
private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class); private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class);
private final String jdbcConnectionString; private final String jdbcConnectionString;
private final String username; private final String username;
private final String password; private final String password;
public MigrationExecutor(String jdbcConnectionString, String username, String password) { public MigrationExecutor(final String jdbcConnectionString, final String username, final String password) {
this.jdbcConnectionString = jdbcConnectionString; this.jdbcConnectionString = jdbcConnectionString;
this.username = username; this.username = username;
this.password = password; this.password = password;
} }
public void migrate(InputStream input) throws DBMigrationException { public void migrate(final InputStream input) throws DBMigrationException {
try (Connection con = DriverManager try (final Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password); .getConnection(this.jdbcConnectionString, this.username, this.password);
Statement stmp = con.createStatement(); final Statement stmp = con.createStatement();
) { ) {
String content = new String(input.readAllBytes(), StandardCharsets.UTF_8); final String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
stmp.execute(content); stmp.execute(content);
} catch (SQLException | IOException e) { } catch (SQLException | IOException e) {
throw new DBMigrationException(e); throw new DBMigrationException(e);

View file

@ -12,10 +12,6 @@
<version>1.0-SNAPSHOT</version> <version>1.0-SNAPSHOT</version>
<name>web</name> <name>web</name>
<properties>
<node.version>v18.16.0</node.version>
</properties>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>de.hhhammer.dchat</groupId> <groupId>de.hhhammer.dchat</groupId>

View file

@ -2,41 +2,41 @@ package de.hhhammer.dchat.web;
import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.db.ServerDBService; import de.hhhammer.dchat.db.PostgresServerDBService;
import de.hhhammer.dchat.db.UserDBService; import de.hhhammer.dchat.db.PostgresUserDBService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
/** /**
* Hello world! * Hello world!
*/ */
public class App { public final class App {
private static final Logger logger = LoggerFactory.getLogger(App.class); private static final Logger logger = LoggerFactory.getLogger(App.class);
public static void main(String[] args) { public static void main(final String[] args) {
String postgresUser = System.getenv("POSTGRES_USER"); final String postgresUser = System.getenv("POSTGRES_USER");
String postgresPassword = System.getenv("POSTGRES_PASSWORD"); final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
String postgresUrl = System.getenv("POSTGRES_URL"); final String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) { if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL"); logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1); System.exit(1);
} }
String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080"; final String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
int apiPort = Integer.parseInt(apiPortStr); final int apiPort = Integer.parseInt(apiPortStr);
boolean debug = "true".equals(System.getenv("API_DEBUG")); final boolean debug = "true".equals(System.getenv("API_DEBUG"));
var config = new HikariConfig(); final var config = new HikariConfig();
config.setJdbcUrl(postgresUrl); config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser); config.setUsername(postgresUser);
config.setPassword(postgresPassword); config.setPassword(postgresPassword);
try (var ds = new HikariDataSource(config)) { try (final var ds = new HikariDataSource(config)) {
var serverDBService = new ServerDBService(ds); final var serverDBService = new PostgresServerDBService(ds);
var userDBService = new UserDBService(ds); final var userDBService = new PostgresUserDBService(ds);
var appConfig = new AppConfig(apiPort, debug); final var appConfig = new AppConfig(apiPort, debug);
var webApi = new WebAPI(serverDBService, userDBService, appConfig); final var webApi = new WebAPI(serverDBService, userDBService, appConfig);
webApi.run(); webApi.run();
} }
} }

View file

@ -6,7 +6,6 @@ import de.hhhammer.dchat.web.server.ConfigCrudHandler;
import de.hhhammer.dchat.web.user.ConfigUserCrudHandler; import de.hhhammer.dchat.web.user.ConfigUserCrudHandler;
import io.javalin.Javalin; import io.javalin.Javalin;
import io.javalin.http.HttpStatus; import io.javalin.http.HttpStatus;
import io.javalin.http.staticfiles.Location;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -16,13 +15,13 @@ import java.util.concurrent.ExecutionException;
import static io.javalin.apibuilder.ApiBuilder.crud; import static io.javalin.apibuilder.ApiBuilder.crud;
import static io.javalin.apibuilder.ApiBuilder.path; import static io.javalin.apibuilder.ApiBuilder.path;
public class WebAPI implements Runnable { public final class WebAPI implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(WebAPI.class); private static final Logger logger = LoggerFactory.getLogger(WebAPI.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
private final UserDBService userDBService; private final UserDBService userDBService;
private final AppConfig appConfig; private final AppConfig appConfig;
public WebAPI(ServerDBService serverDBService, UserDBService userDBService, AppConfig appConfig) { public WebAPI(final ServerDBService serverDBService, final UserDBService userDBService, final AppConfig appConfig) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
this.userDBService = userDBService; this.userDBService = userDBService;
this.appConfig = appConfig; this.appConfig = appConfig;
@ -31,12 +30,12 @@ public class WebAPI implements Runnable {
@Override @Override
public void run() { public void run() {
logger.info("Starting web application"); logger.info("Starting web application");
var app = Javalin.create(config -> { final Javalin app = Javalin.create(config -> {
if (appConfig.debug()) config.plugins.enableDevLogging(); if (appConfig.debug()) config.plugins.enableDevLogging();
config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method
config.http.defaultContentType = "application/json"; config.http.defaultContentType = "application/json";
}); });
var waitForShutdown = new CompletableFuture<Void>(); final var waitForShutdown = new CompletableFuture<Void>();
Runtime.getRuntime().addShutdownHook(new Thread(() -> { Runtime.getRuntime().addShutdownHook(new Thread(() -> {
logger.info("Shutting down web application"); logger.info("Shutting down web application");
app.stop(); app.stop();

View file

@ -10,18 +10,21 @@ import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
public class ConfigCrudHandler implements CrudHandler { import java.util.List;
import java.util.Optional;
public final class ConfigCrudHandler implements CrudHandler {
private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class); private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class);
private final ServerDBService serverDBService; private final ServerDBService serverDBService;
public ConfigCrudHandler(ServerDBService serverDBService) { public ConfigCrudHandler(final ServerDBService serverDBService) {
this.serverDBService = serverDBService; this.serverDBService = serverDBService;
} }
@Override @Override
public void create(@NotNull Context context) { public void create(@NotNull final Context context) {
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class); final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
try { try {
this.serverDBService.addConfig(body); this.serverDBService.addConfig(body);
} catch (DBException e) { } catch (DBException e) {
@ -32,9 +35,10 @@ public class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void delete(@NotNull Context context, @NotNull String s) { public void delete(@NotNull final Context context, @NotNull final String s) {
final var id = Long.parseLong(s);
try { try {
this.serverDBService.deleteConfig(Long.parseLong(s)); this.serverDBService.deleteConfig(id);
context.status(HttpStatus.NO_CONTENT); context.status(HttpStatus.NO_CONTENT);
} catch (DBException e) { } catch (DBException e) {
logger.error("Deleting configuration with id: " + s, e); logger.error("Deleting configuration with id: " + s, e);
@ -43,9 +47,9 @@ public class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void getAll(@NotNull Context context) { public void getAll(@NotNull final Context context) {
try { try {
var allowedServers = this.serverDBService.getAllConfigs(); final List<ServerConfig> allowedServers = this.serverDBService.getAllConfigs();
context.json(allowedServers); context.json(allowedServers);
} catch (DBException e) { } catch (DBException e) {
logger.error("Getting all server configs", e); logger.error("Getting all server configs", e);
@ -54,10 +58,10 @@ public class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void getOne(@NotNull Context context, @NotNull String s) { public void getOne(@NotNull final Context context, @NotNull final String s) {
var id = Long.parseLong(s); final var id = Long.parseLong(s);
try { try {
var server = this.serverDBService.getConfigBy(id); final Optional<ServerConfig> server = this.serverDBService.getConfigBy(id);
if (server.isEmpty()) { if (server.isEmpty()) {
context.status(HttpStatus.NOT_FOUND); context.status(HttpStatus.NOT_FOUND);
return; return;
@ -70,9 +74,9 @@ public class ConfigCrudHandler implements CrudHandler {
} }
@Override @Override
public void update(@NotNull Context context, @NotNull String idString) { public void update(@NotNull final Context context, @NotNull final String idString) {
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class); final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
var id = Long.parseLong(idString); final var id = Long.parseLong(idString);
try { try {
this.serverDBService.updateConfig(id, body); this.serverDBService.updateConfig(id, body);

View file

@ -10,18 +10,21 @@ import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
public class ConfigUserCrudHandler implements CrudHandler { import java.util.List;
import java.util.Optional;
public final class ConfigUserCrudHandler implements CrudHandler {
private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class); private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class);
private final UserDBService userDBService; private final UserDBService userDBService;
public ConfigUserCrudHandler(UserDBService userDBService) { public ConfigUserCrudHandler(final UserDBService userDBService) {
this.userDBService = userDBService; this.userDBService = userDBService;
} }
@Override @Override
public void create(@NotNull Context context) { public void create(@NotNull final Context context) {
var body = context.bodyAsClass(UserConfig.NewUserConfig.class); final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
try { try {
this.userDBService.addConfig(body); this.userDBService.addConfig(body);
} catch (DBException e) { } catch (DBException e) {
@ -33,9 +36,10 @@ public class ConfigUserCrudHandler implements CrudHandler {
} }
@Override @Override
public void delete(@NotNull Context context, @NotNull String s) { public void delete(@NotNull final Context context, @NotNull final String s) {
final var id = Long.parseLong(s);
try { try {
this.userDBService.deleteConfig(Long.parseLong(s)); this.userDBService.deleteConfig(id);
context.status(HttpStatus.NO_CONTENT); context.status(HttpStatus.NO_CONTENT);
} catch (DBException e) { } catch (DBException e) {
logger.error("Deleting configuration with id: " + s, e); logger.error("Deleting configuration with id: " + s, e);
@ -44,9 +48,9 @@ public class ConfigUserCrudHandler implements CrudHandler {
} }
@Override @Override
public void getAll(@NotNull Context context) { public void getAll(@NotNull final Context context) {
try { try {
var allowedServers = this.userDBService.getAllConfigs(); final List<UserConfig> allowedServers = this.userDBService.getAllConfigs();
context.json(allowedServers); context.json(allowedServers);
} catch (DBException e) { } catch (DBException e) {
logger.error("Getting all user configs", e); logger.error("Getting all user configs", e);
@ -55,10 +59,10 @@ public class ConfigUserCrudHandler implements CrudHandler {
} }
@Override @Override
public void getOne(@NotNull Context context, @NotNull String s) { public void getOne(@NotNull final Context context, @NotNull final String s) {
var id = Long.parseLong(s); final var id = Long.parseLong(s);
try { try {
var server = this.userDBService.getConfigBy(id); final Optional<UserConfig> server = this.userDBService.getConfigBy(id);
if (server.isEmpty()) { if (server.isEmpty()) {
context.status(HttpStatus.NOT_FOUND); context.status(HttpStatus.NOT_FOUND);
return; return;
@ -72,8 +76,8 @@ public class ConfigUserCrudHandler implements CrudHandler {
@Override @Override
public void update(@NotNull Context context, @NotNull String idString) { public void update(@NotNull Context context, @NotNull String idString) {
var body = context.bodyAsClass(UserConfig.NewUserConfig.class); final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
var id = Long.parseLong(idString); final var id = Long.parseLong(idString);
try { try {
this.userDBService.updateConfig(id, body); this.userDBService.updateConfig(id, body);