Compare commits

..

No commits in common. "168a5d6c818327cbf1d08964acaee4d5646cf9c6" and "5c67a47806c298746576d7b9f3e9fdadd39c335e" have entirely different histories.

24 changed files with 562 additions and 642 deletions

View file

@ -4,47 +4,47 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.db.PostgresServerDBService;
import de.hhhammer.dchat.db.PostgresUserDBService;
import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.UserDBService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.http.HttpClient;
public final class App {
public class App {
private static final Logger logger = LoggerFactory.getLogger(App.class);
public static void main(final String[] args) {
final String discordApiKey = System.getenv("DISCORD_API_KEY");
public static void main(String[] args) {
String discordApiKey = System.getenv("DISCORD_API_KEY");
if (discordApiKey == null) {
logger.error("Missing environment variables: DISCORD_API_KEY");
System.exit(1);
}
final String openaiApiKey = System.getenv("OPENAI_API_KEY");
String openaiApiKey = System.getenv("OPENAI_API_KEY");
if (openaiApiKey == null) {
logger.error("Missing environment variables: OPENAI_API_KEY");
System.exit(1);
}
final String postgresUser = System.getenv("POSTGRES_USER");
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
final String postgresUrl = System.getenv("POSTGRES_URL");
String postgresUser = System.getenv("POSTGRES_USER");
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1);
}
final var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
final var config = new HikariConfig();
var config = new HikariConfig();
config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser);
config.setPassword(postgresPassword);
try (var ds = new HikariDataSource(config)) {
final var serverDBService = new PostgresServerDBService(ds);
final var userDBService = new PostgresUserDBService(ds);
var serverDBService = new ServerDBService(ds);
var userDBService = new UserDBService(ds);
final var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
discordBot.run();
}
}

View file

@ -9,7 +9,6 @@ import de.hhhammer.dchat.db.UserDBService;
import org.javacord.api.DiscordApi;
import org.javacord.api.DiscordApiBuilder;
import org.javacord.api.interaction.SlashCommand;
import org.javacord.api.interaction.SlashCommandInteraction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -17,7 +16,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
public final class DiscordBot implements Runnable {
public class DiscordBot implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class);
private final ServerDBService serverDBService;
@ -25,7 +24,7 @@ public final class DiscordBot implements Runnable {
private final ChatGPTService chatGPTService;
private final String discordApiKey;
public DiscordBot(final ServerDBService serverDBService, final UserDBService userDBService, final ChatGPTService chatGPTService, final String discordApiKey) {
public DiscordBot(ServerDBService serverDBService, UserDBService userDBService, ChatGPTService chatGPTService, String discordApiKey) {
this.serverDBService = serverDBService;
this.userDBService = userDBService;
this.chatGPTService = chatGPTService;
@ -35,29 +34,29 @@ public final class DiscordBot implements Runnable {
@Override
public void run() {
logger.info("Starting Discord application");
final DiscordApi discordApi = new DiscordApiBuilder()
DiscordApi discordApi = new DiscordApiBuilder()
.setToken(discordApiKey)
.login()
.join();
discordApi.setMessageCacheSize(10, 60*60);
final var future = new CompletableFuture<Void>();
var future = new CompletableFuture<Void>();
Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> {
logger.info("Shutting down Discord application");
discordApi.disconnect().thenAccept(future::complete);
}));
final SlashCommand token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
var token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
.createGlobal(discordApi)
.join();
discordApi.addSlashCommandCreateListener(event -> {
logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName());
final SlashCommandInteraction command = event.getSlashCommandInteraction();
var command = event.getSlashCommandInteraction();
if (token.getFullCommandNames().contains(command.getFullCommandName())) {
event.getInteraction()
.respondLater()
.orTimeout(30, TimeUnit.SECONDS)
.thenAccept((interactionOriginalResponseUpdater) -> {
final long tokens = event.getInteraction().getServer().isPresent() ?
var tokens = event.getInteraction().getServer().isPresent() ?
this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) :
this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId()));
interactionOriginalResponseUpdater.setContent("" + tokens).update();

View file

@ -9,17 +9,17 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
public final class MessageCreateHandler implements MessageCreateListener {
public class MessageCreateHandler implements MessageCreateListener {
private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class);
private final MessageHandler messageHandler;
public MessageCreateHandler(final MessageHandler messageHandler) {
public MessageCreateHandler(MessageHandler messageHandler) {
this.messageHandler = messageHandler;
}
@Override
public void onMessageCreate(final MessageCreateEvent event) {
public void onMessageCreate(MessageCreateEvent event) {
Thread.ofVirtual().start(() -> {
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
return;
@ -36,15 +36,15 @@ public final class MessageCreateHandler implements MessageCreateListener {
}
try {
this.messageHandler.handle(event);
} catch (final ResponseException | IOException | InterruptedException e) {
} catch (ResponseException | IOException | InterruptedException e) {
logger.error("Reading a message from the listener", e);
event.getMessage().reply("Sorry but something went wrong :(");
}
});
}
private boolean isNormalOrReplyMessageType(final MessageCreateEvent event) {
final MessageType type = event.getMessage().getType();
private boolean isNormalOrReplyMessageType(MessageCreateEvent event) {
MessageType type = event.getMessage().getType();
return type == MessageType.NORMAL || type == MessageType.REPLY;
}
}

View file

@ -4,10 +4,7 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.javacord.api.entity.DiscordEntity;
import org.javacord.api.entity.message.Message;
@ -20,44 +17,43 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
public final class ServerMessageHandler implements MessageHandler {
public class ServerMessageHandler implements MessageHandler {
private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class);
private final ServerDBService serverDBService;
private final ChatGPTService chatGPTService;
public ServerMessageHandler(final ServerDBService serverDBService, final ChatGPTService chatGPTService) {
public ServerMessageHandler(ServerDBService serverDBService, ChatGPTService chatGPTService) {
this.serverDBService = serverDBService;
this.chatGPTService = chatGPTService;
}
@Override
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
final String content = extractContent(event);
final long serverId = event.getServer().get().getId();
final String systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
final List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
final ChatGPTResponse response = this.chatGPTService.submit(request);
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
String content = extractContent(event);
var serverId = event.getServer().get().getId();
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
var request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
var response = this.chatGPTService.submit(request);
if (response.choices().isEmpty()) {
event.getMessage().reply("No response available");
return;
}
final String answer = response.choices().get(0).message().content();
var answer = response.choices().get(0).message().content();
logServerMessage(event, response.usage().totalTokens());
event.getMessage().reply(answer);
}
@Override
public boolean isAllowed(final MessageCreateEvent event) {
public boolean isAllowed(MessageCreateEvent event) {
if (event.getServer().isEmpty()) {
return false;
}
final long serverId = event.getServer().get().getId();
final Optional<ServerConfig> config = this.serverDBService.getConfig(String.valueOf(serverId));
var serverId = event.getServer().get().getId();
var config = this.serverDBService.getConfig(String.valueOf(serverId));
if (config.isEmpty()) {
logger.debug("Not allowed with id: " + serverId);
return false;
@ -66,39 +62,39 @@ public final class ServerMessageHandler implements MessageHandler {
}
@Override
public boolean exceedsRate(final MessageCreateEvent event) {
final String serverId = String.valueOf(event.getServer().get().getId());
final Optional<ServerConfig> config = this.serverDBService.getConfig(serverId);
public boolean exceedsRate(MessageCreateEvent event) {
var serverId = String.valueOf(event.getServer().get().getId());
var config = this.serverDBService.getConfig(serverId);
if (config.isEmpty()) {
logger.error("Missing configuration for server with id: " + serverId);
return true;
}
final int rateLimit = config.get().rateLimit();
final int countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
var rateLimit = config.get().rateLimit();
var countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
return countMessagesInLastMinute >= rateLimit;
}
@Override
public boolean canHandle(final MessageCreateEvent event) {
public boolean canHandle(MessageCreateEvent event) {
return event.isServerMessage();
}
private void logServerMessage(final MessageCreateEvent event, final int tokens) {
final long serverId = event.getServer().map(DiscordEntity::getId).get();
final long userId = event.getMessageAuthor().getId();
private void logServerMessage(MessageCreateEvent event, int tokens) {
var serverId = event.getServer().map(DiscordEntity::getId).get();
var userId = event.getMessageAuthor().getId();
final var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
this.serverDBService.addMessage(serverMessage);
}
private String extractContent(final MessageCreateEvent event) {
final long ownId = event.getApi().getYourself().getId();
private String extractContent(MessageCreateEvent event) {
long ownId = event.getApi().getYourself().getId();
return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
}
@NotNull
private List<ReplyInteraction> getContextMessages(final MessageCreateEvent event) {
private List<ReplyInteraction> getContextMessages(MessageCreateEvent event) {
return event.getMessage()
.getMessageReference()
.map(MessageReference::getMessage)

View file

@ -4,10 +4,7 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
import de.hhhammer.dchat.bot.openai.ResponseException;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import de.hhhammer.dchat.db.UserDBService;
import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage;
import org.javacord.api.event.message.MessageCreateEvent;
import org.slf4j.Logger;
@ -15,47 +12,46 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
public final class UserMessageHandler implements MessageHandler {
public class UserMessageHandler implements MessageHandler {
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
private final UserDBService userDBService;
private final ChatGPTService chatGPTService;
public UserMessageHandler(final UserDBService userDBService, final ChatGPTService chatGPTService) {
public UserMessageHandler(UserDBService userDBService, ChatGPTService chatGPTService) {
this.userDBService = userDBService;
this.chatGPTService = chatGPTService;
}
@Override
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
final String content = event.getReadableMessageContent();
final String userId = String.valueOf(event.getMessageAuthor().getId());
final UserConfig config = this.userDBService.getConfig(userId).get();
final String systemMessage = config.systemMessage();
final List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
String content = event.getReadableMessageContent();
var userId = String.valueOf(event.getMessageAuthor().getId());
var config = this.userDBService.getConfig(userId).get();
var systemMessage = config.systemMessage();
List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
.stream()
.map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
.toList();
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
final ChatGPTResponse response = this.chatGPTService.submit(request);
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
var response = this.chatGPTService.submit(request);
if (response.choices().isEmpty()) {
event.getMessage().reply("No response available");
return;
}
final String answer = response.choices().get(0).message().content();
var answer = response.choices().get(0).message().content();
logUserMessage(event, content, answer, response.usage().totalTokens());
event.getChannel().sendMessage(answer);
}
@Override
public boolean isAllowed(final MessageCreateEvent event) {
public boolean isAllowed(MessageCreateEvent event) {
if (event.getServer().isPresent()) {
return false;
}
final long userId = event.getMessageAuthor().getId();
final Optional<UserConfig> config = this.userDBService.getConfig(String.valueOf(userId));
var userId = event.getMessageAuthor().getId();
var config = this.userDBService.getConfig(String.valueOf(userId));
if (config.isEmpty()) {
logger.debug("Not allowed with id: " + userId);
return false;
@ -64,28 +60,28 @@ public final class UserMessageHandler implements MessageHandler {
}
@Override
public boolean exceedsRate(final MessageCreateEvent event) {
final String userId = String.valueOf(event.getMessageAuthor().getId());
final Optional<UserConfig> config = this.userDBService.getConfig(userId);
public boolean exceedsRate(MessageCreateEvent event) {
var userId = String.valueOf(event.getMessageAuthor().getId());
var config = this.userDBService.getConfig(userId);
if (config.isEmpty()) {
logger.error("Missing configuration for userId with id: " + userId);
return true;
}
final int rateLimit = config.get().rateLimit();
final int countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
var rateLimit = config.get().rateLimit();
var countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
return countMessagesInLastMinute >= rateLimit;
}
@Override
public boolean canHandle(final MessageCreateEvent event) {
public boolean canHandle(MessageCreateEvent event) {
return event.isPrivateMessage();
}
private void logUserMessage(final MessageCreateEvent event, final String question, final String answer, final int tokens) {
final long userId = event.getMessageAuthor().getId();
private void logUserMessage(MessageCreateEvent event, String question, String answer, int tokens) {
var userId = event.getMessageAuthor().getId();
final UserMessage.NewUserMessage userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
var userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
this.userDBService.addMessage(userMessage);
}
}

View file

@ -8,19 +8,29 @@ import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.List;
public final class ChatGPTRequestBuilder {
public class ChatGPTRequestBuilder {
private static final String smallContextModel = "gpt-3.5-turbo";
private static final String bigContextModel = "gpt-3.5-turbo-16k";
public ChatGPTRequestBuilder() {
}
public ChatGPTRequest contextRequest(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
return new ChatGPTRequest(
contextModel,
messages,
0.7f
);
}
@NotNull
private static List<ChatGPTRequest.Message> getMessages(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
final ChatGPTRequest.Message systemMsg = new ChatGPTRequest.Message("system", systemMessage);
final List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
final ChatGPTRequest.Message userMessage = new ChatGPTRequest.Message("user", message);
final List<ChatGPTRequest.Message> messages = new ArrayList<>();
private static List<ChatGPTRequest.Message> getMessages(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
var systemMsg = new ChatGPTRequest.Message("system", systemMessage);
List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
var userMessage = new ChatGPTRequest.Message("user", message);
List<ChatGPTRequest.Message> messages = new ArrayList<>();
messages.add(systemMsg);
messages.addAll(contextMsgs);
messages.add(userMessage);
@ -28,7 +38,7 @@ public final class ChatGPTRequestBuilder {
}
@NotNull
private static List<ChatGPTRequest.Message> getContextMessages(final List<? extends MessageContext> contextMessages) {
private static List<ChatGPTRequest.Message> getContextMessages(List<? extends MessageContext> contextMessages) {
return contextMessages.stream()
.map(ChatGPTRequestBuilder::mapContextMessages)
.flatMap(List::stream)
@ -36,7 +46,7 @@ public final class ChatGPTRequestBuilder {
}
@NotNull
private static List<ChatGPTRequest.Message> mapContextMessages(final MessageContext contextMessage) {
private static List<ChatGPTRequest.Message> mapContextMessages(MessageContext contextMessage) {
return switch (contextMessage) {
case PreviousInteraction previousInteractions -> List.of(
new ChatGPTRequest.Message("user", previousInteractions.question()),
@ -46,14 +56,4 @@ public final class ChatGPTRequestBuilder {
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
};
}
public ChatGPTRequest contextRequest(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
final List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
final String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
return new ChatGPTRequest(
contextModel,
messages,
0.7f
);
}
}

View file

@ -5,35 +5,34 @@ import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
public final class ChatGPTService {
public class ChatGPTService {
private static final String url = "https://api.openai.com/v1/chat/completions";
private final String apiKey;
private final HttpClient httpClient;
private final ObjectMapper mapper;
public ChatGPTService(final String apiKey, final HttpClient httpClient, final ObjectMapper mapper) {
public ChatGPTService(String apiKey, HttpClient httpClient, ObjectMapper mapper) {
this.apiKey = apiKey;
this.httpClient = httpClient;
this.mapper = mapper;
}
public ChatGPTResponse submit(final ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
final byte[] data = mapper.writeValueAsBytes(chatGPTRequest);
final HttpRequest request = HttpRequest.newBuilder(URI.create(url))
public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
var data = mapper.writeValueAsBytes(chatGPTRequest);
var request = HttpRequest.newBuilder(URI.create(url))
.POST(HttpRequest.BodyPublishers.ofByteArray(data))
.setHeader("Content-Type", "application/json")
.setHeader("Authorization", "Bearer " + this.apiKey)
.timeout(Duration.ofMinutes(5))
.build();
final HttpResponse<InputStream> responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
var responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
if (responseStream.statusCode() != 200) {
throw new ResponseException("Response status code was not 200: " + responseStream.statusCode());
}

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.bot.openai;
public final class ResponseException extends Exception {
public ResponseException(final String message) {
public class ResponseException extends Exception {
public ResponseException(String message) {
super(message);
}
}

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.db;
public final class DBException extends Exception {
public DBException(final String message, final Throwable cause) {
public class DBException extends Exception {
public DBException(String message, Throwable cause) {
super(message, cause);
}
}

View file

@ -1,202 +0,0 @@
package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public final class PostgresServerDBService implements ServerDBService {
private static final Logger logger = LoggerFactory.getLogger(PostgresServerDBService.class);
private final DataSource dataSource;
public PostgresServerDBService(final DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public Optional<ServerConfig> getConfig(final String serverId) {
final var getServerConfig = """
SELECT * FROM server_configs WHERE server_id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
logger.error("Getting configuration for server with id: " + serverId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
}
return Optional.empty();
}
@Override
public List<ServerConfig> getAllConfigs() throws DBException {
final var getAllowedServerSql = """
SELECT * FROM server_configs
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
) {
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
throw new DBException("Loading all configs", e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over configs", e);
}
}
@Override
public Optional<ServerConfig> getConfigBy(final long id) throws DBException {
final var getServerConfig = """
SELECT * FROM server_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration with id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
}
}
@Override
public void addConfig(final ServerConfig.NewServerConfig newServerConfig) throws DBException {
final var getServerConfig = """
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.serverId());
pstmt.setString(2, newServerConfig.systemMessage());
pstmt.setInt(3, newServerConfig.rateLimit());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for server with id: " + newServerConfig.serverId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
@Override
public void updateConfig(final long id, final ServerConfig.NewServerConfig newServerConfig) throws DBException {
final var getServerConfig = """
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.systemMessage());
pstmt.setInt(2, newServerConfig.rateLimit());
pstmt.setString(3, newServerConfig.serverId());
pstmt.setLong(4, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update for server with id: " + newServerConfig.serverId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
@Override
public void deleteConfig(final long id) throws DBException {
final var getServerConfig = """
DELETE FROM server_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for server with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Deleting configuration for server with id: " + id, e);
}
}
@Override
public int countMessagesInLastMinute(final String serverId) {
final var getServerConfig = """
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getInt(1);
} catch (final SQLException e) {
logger.error("Getting messages for server with id: " + serverId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
}
return Integer.MAX_VALUE;
}
@Override
public void addMessage(final ServerMessage.NewServerMessage serverMessage) {
final var getServerConfig = """
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverMessage.serverId());
pstmt.setLong(2, serverMessage.userId());
pstmt.setInt(3, serverMessage.tokens());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for server with id: " + serverMessage.serverId());
}
} catch (final SQLException e) {
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
}
}
@Override
public long tokensOfLast30Days(final String serverId) {
final var countTokensOfLast30Days = """
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, serverId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getLong(1);
} catch (final SQLException e) {
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
}
logger.error("No tokens found for server with id: " + serverId);
return 0;
}
}

View file

@ -1,226 +0,0 @@
package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public final class PostgresUserDBService implements UserDBService {
private static final Logger logger = LoggerFactory.getLogger(PostgresUserDBService.class);
private final DataSource dataSource;
public PostgresUserDBService(final DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public Optional<UserConfig> getConfig(final String userId) {
final var getServerConfig = """
SELECT * FROM user_configs WHERE user_id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
logger.error("Getting configuration for user with id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
}
return Optional.empty();
}
@Override
public Optional<UserConfig> getConfigBy(final long id) throws DBException {
final var getServerConfig = """
SELECT * FROM user_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (final SQLException e) {
throw new DBException("Getting configuration id: " + id, e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
}
}
@Override
public List<UserConfig> getAllConfigs() throws DBException {
final var getServerConfig = """
SELECT * FROM user_configs
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
throw new DBException("Getting all configurations", e);
} catch (final ResultSetIteratorException e) {
throw new DBException("Iterating over all UserConfig ResultSet", e);
}
}
@Override
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
final var getServerConfig = """
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.userId());
pstmt.setString(2, newUserConfig.systemMessage());
pstmt.setInt(3, newUserConfig.contextLength());
pstmt.setInt(4, newUserConfig.rateLimit());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for user with id: " + newUserConfig.userId());
}
} catch (final SQLException e) {
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
}
}
@Override
public void updateConfig(final long id, final UserConfig.NewUserConfig newUserConfig) throws DBException {
final var getServerConfig = """
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.systemMessage());
pstmt.setInt(2, newUserConfig.rateLimit());
pstmt.setLong(3, newUserConfig.contextLength());
pstmt.setString(4, newUserConfig.userId());
pstmt.setLong(5, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Updating configuration with id: " + id, e);
}
}
@Override
public void deleteConfig(final long id) throws DBException {
final var getServerConfig = """
DELETE FROM user_configs WHERE id = ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for user with id: " + id);
}
} catch (final SQLException e) {
throw new DBException("Deleting configuration with id: " + id, e);
}
}
@Override
public int countMessagesInLastMinute(final String userId) {
final var getServerConfig = """
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getInt(1);
} catch (final SQLException e) {
logger.error("Getting messages for user with id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
}
return Integer.MAX_VALUE;
}
@Override
public void addMessage(final UserMessage.NewUserMessage newUserMessage) {
final var getServerConfig = """
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserMessage.userId());
pstmt.setString(2, newUserMessage.question());
pstmt.setString(3, newUserMessage.answer());
pstmt.setInt(4, newUserMessage.tokens());
final int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for user with id: " + newUserMessage.userId());
}
} catch (final SQLException e) {
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
}
}
@Override
public List<UserMessage> getLastMessages(final String userId, final int limit) {
final var getLastMessages = """
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(getLastMessages)
) {
pstmt.setString(1, userId);
pstmt.setInt(2, limit);
final ResultSet resultSet = pstmt.executeQuery();
final Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (final SQLException e) {
logger.error("Fetching last messages for user whit id: " + userId, e);
} catch (final ResultSetIteratorException e) {
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
}
return List.of();
}
@Override
public long tokensOfLast30Days(final String userId) {
final var countTokensOfLast30Days = """
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
""";
try (final Connection con = dataSource.getConnection();
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, userId);
final var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
final ResultSet resultSet = pstmt.executeQuery();
if (resultSet.next()) return resultSet.getLong(1);
} catch (final SQLException e) {
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
}
logger.error("No tokens found for user with id: " + userId);
return 0;
}
}

View file

@ -4,11 +4,11 @@ import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Iterator;
public final class ResultSetIterator<T> implements Iterator<T> {
public class ResultSetIterator<T> implements Iterator<T> {
private final ResultSet resultSet;
private final ResultSetTransformer<T> transformer;
public ResultSetIterator(final ResultSet resultSet, final ResultSetTransformer<T> transformer) {
public ResultSetIterator(ResultSet resultSet, ResultSetTransformer<T> transformer) {
this.resultSet = resultSet;
this.transformer = transformer;
}

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.db;
public final class ResultSetIteratorException extends RuntimeException {
public ResultSetIteratorException(final Throwable cause) {
public class ResultSetIteratorException extends RuntimeException {
public ResultSetIteratorException(Throwable cause) {
super(cause);
}
}

View file

@ -2,26 +2,196 @@ package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.server.ServerConfig;
import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public interface ServerDBService {
Optional<ServerConfig> getConfig(String serverId);
public class ServerDBService {
private static final Logger logger = LoggerFactory.getLogger(ServerDBService.class);
private final DataSource dataSource;
List<ServerConfig> getAllConfigs() throws DBException;
Optional<ServerConfig> getConfigBy(long id) throws DBException;
void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException;
void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException;
void deleteConfig(long id) throws DBException;
int countMessagesInLastMinute(String serverId);
void addMessage(ServerMessage.NewServerMessage serverMessage);
long tokensOfLast30Days(String serverId);
public ServerDBService(DataSource dataSource) {
this.dataSource = dataSource;
}
public Optional<ServerConfig> getConfig(String serverId) {
var getServerConfig = """
SELECT * FROM server_configs WHERE server_id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
logger.error("Getting configuration for server with id: " + serverId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
return Optional.empty();
}
return Optional.empty();
}
public List<ServerConfig> getAllConfigs() throws DBException {
var getAllowedServerSql = """
SELECT * FROM server_configs
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
) {
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
throw new DBException("Loading all configs", e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over configs", e);
}
}
public Optional<ServerConfig> getConfigBy(long id) throws DBException {
var getServerConfig = """
SELECT * FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
ResultSet resultSet = pstmt.executeQuery();
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration with id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
}
}
public void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException {
var getServerConfig = """
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.serverId());
pstmt.setString(2, newServerConfig.systemMessage());
pstmt.setInt(3, newServerConfig.rateLimit());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for server with id: " + newServerConfig.serverId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
public void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException {
var getServerConfig = """
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.systemMessage());
pstmt.setInt(2, newServerConfig.rateLimit());
pstmt.setString(3, newServerConfig.serverId());
pstmt.setLong(4, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update for server with id: " + newServerConfig.serverId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
}
}
public void deleteConfig(long id) throws DBException {
var getServerConfig = """
DELETE FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for server with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Deleting configuration for server with id: " + id, e);
}
}
public int countMessagesInLastMinute(String serverId) {
var getServerConfig = """
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getInt(1);
} catch (SQLException e) {
logger.error("Getting messages for server with id: " + serverId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
return Integer.MAX_VALUE;
}
return Integer.MAX_VALUE;
}
public void addMessage(ServerMessage.NewServerMessage serverMessage) {
var getServerConfig = """
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverMessage.serverId());
pstmt.setLong(2, serverMessage.userId());
pstmt.setInt(3, serverMessage.tokens());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for server with id: " + serverMessage.serverId());
}
} catch (SQLException e) {
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
}
}
public long tokensOfLast30Days(String serverId) {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, serverId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getLong(1);
} catch (SQLException e) {
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
}
logger.error("No tokens found for server with id: " + serverId);
return 0;
}
}

View file

@ -2,28 +2,219 @@ package de.hhhammer.dchat.db;
import de.hhhammer.dchat.db.models.user.UserConfig;
import de.hhhammer.dchat.db.models.user.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.stream.StreamSupport;
public interface UserDBService {
Optional<UserConfig> getConfig(String userId);
public class UserDBService {
private static final Logger logger = LoggerFactory.getLogger(UserDBService.class);
private final DataSource dataSource;
Optional<UserConfig> getConfigBy(long id) throws DBException;
List<UserConfig> getAllConfigs() throws DBException;
void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException;
void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException;
void deleteConfig(long id) throws DBException;
int countMessagesInLastMinute(String userId);
void addMessage(UserMessage.NewUserMessage newUserMessage);
List<UserMessage> getLastMessages(String userId, int limit);
long tokensOfLast30Days(String userId);
public UserDBService(DataSource dataSource) {
this.dataSource = dataSource;
}
public Optional<UserConfig> getConfig(String userId) {
var getServerConfig = """
SELECT * FROM user_configs WHERE user_id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
logger.error("Getting configuration for user with id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
return Optional.empty();
}
return Optional.empty();
}
public Optional<UserConfig> getConfigBy(long id) throws DBException {
var getServerConfig = """
SELECT * FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
} catch (SQLException e) {
throw new DBException("Getting configuration id: " + id, e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
}
}
public List<UserConfig> getAllConfigs() throws DBException {
var getServerConfig = """
SELECT * FROM user_configs
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
throw new DBException("Getting all configurations", e);
} catch (ResultSetIteratorException e) {
throw new DBException("Iterating over all UserConfig ResultSet", e);
}
}
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
var getServerConfig = """
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.userId());
pstmt.setString(2, newUserConfig.systemMessage());
pstmt.setInt(3, newUserConfig.contextLength());
pstmt.setInt(4, newUserConfig.rateLimit());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config added for user with id: " + newUserConfig.userId());
}
} catch (SQLException e) {
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
}
}
public void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException {
var getServerConfig = """
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.systemMessage());
pstmt.setInt(2, newUserConfig.rateLimit());
pstmt.setLong(3, newUserConfig.contextLength());
pstmt.setString(4, newUserConfig.userId());
pstmt.setLong(5, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config update with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Updating configuration with id: " + id, e);
}
}
public void deleteConfig(long id) throws DBException {
var getServerConfig = """
DELETE FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No config deleted for user with id: " + id);
}
} catch (SQLException e) {
throw new DBException("Deleting configuration with id: " + id, e);
}
}
public int countMessagesInLastMinute(String userId) {
var getServerConfig = """
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getInt(1);
} catch (SQLException e) {
logger.error("Getting messages for user with id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
return Integer.MAX_VALUE;
}
return Integer.MAX_VALUE;
}
public void addMessage(UserMessage.NewUserMessage newUserMessage) {
var getServerConfig = """
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserMessage.userId());
pstmt.setString(2, newUserMessage.question());
pstmt.setString(3, newUserMessage.answer());
pstmt.setInt(4, newUserMessage.tokens());
int affectedRows = pstmt.executeUpdate();
if (affectedRows == 0) {
logger.error("No message added for user with id: " + newUserMessage.userId());
}
} catch (SQLException e) {
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
}
}
public List<UserMessage> getLastMessages(String userId, int limit) {
var getLastMessages = """
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getLastMessages)
) {
pstmt.setString(1, userId);
pstmt.setInt(2, limit);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
logger.error("Fetching last messages for user whit id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
}
return List.of();
}
public long tokensOfLast30Days(String userId) {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, userId);
var now = Instant.now();
pstmt.setTimestamp(2, Timestamp.from(now));
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
ResultSet resultSet = pstmt.executeQuery();
resultSet.next();
return resultSet.getLong(1);
} catch (SQLException e) {
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
}
logger.error("No tokens found for user with id: " + userId);
return 0;
}
}

View file

@ -6,20 +6,20 @@ import org.slf4j.LoggerFactory;
/**
* Hello world!
*/
public final class App {
public class App {
private static final Logger logger = LoggerFactory.getLogger(App.class);
private static final String DB_MIGRATION_PATH = "db/schema.sql";
public static void main(final String[] args) {
final String postgresUser = System.getenv("POSTGRES_USER");
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
final String postgresUrl = System.getenv("POSTGRES_URL");
public static void main(String[] args) {
String postgresUser = System.getenv("POSTGRES_USER");
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1);
}
final var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
final var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
dbMigrator.run();
}
}

View file

@ -1,7 +1,7 @@
package de.hhhammer.dchat.migration;
public final class DBMigrationException extends Exception {
public DBMigrationException(final Throwable cause) {
public class DBMigrationException extends Exception {
public DBMigrationException(Throwable cause) {
super(cause);
}
}

View file

@ -6,12 +6,12 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
public final class DBMigrator implements Runnable {
public class DBMigrator implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class);
private final MigrationExecutor migrationExecutor;
private final String resourcePath;
public DBMigrator(final MigrationExecutor migrationExecutor, final String resourcePath) {
public DBMigrator(MigrationExecutor migrationExecutor, String resourcePath) {
this.migrationExecutor = migrationExecutor;
this.resourcePath = resourcePath;
}
@ -19,8 +19,8 @@ public final class DBMigrator implements Runnable {
@Override
public void run() {
logger.info("Starting db migration");
final ClassLoader classLoader = getClass().getClassLoader();
try (final InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
ClassLoader classLoader = getClass().getClassLoader();
try (InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
if (inputStream == null) {
logger.error("Migration file not found: " + resourcePath);
throw new RuntimeException("Migration file not found");

View file

@ -11,24 +11,24 @@ import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
public final class MigrationExecutor {
public class MigrationExecutor {
private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class);
private final String jdbcConnectionString;
private final String username;
private final String password;
public MigrationExecutor(final String jdbcConnectionString, final String username, final String password) {
public MigrationExecutor(String jdbcConnectionString, String username, String password) {
this.jdbcConnectionString = jdbcConnectionString;
this.username = username;
this.password = password;
}
public void migrate(final InputStream input) throws DBMigrationException {
try (final Connection con = DriverManager
public void migrate(InputStream input) throws DBMigrationException {
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
final Statement stmp = con.createStatement();
Statement stmp = con.createStatement();
) {
final String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
stmp.execute(content);
} catch (SQLException | IOException e) {
throw new DBMigrationException(e);

View file

@ -12,6 +12,10 @@
<version>1.0-SNAPSHOT</version>
<name>web</name>
<properties>
<node.version>v18.16.0</node.version>
</properties>
<dependencies>
<dependency>
<groupId>de.hhhammer.dchat</groupId>

View file

@ -2,41 +2,41 @@ package de.hhhammer.dchat.web;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.db.PostgresServerDBService;
import de.hhhammer.dchat.db.PostgresUserDBService;
import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.UserDBService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Hello world!
*/
public final class App {
public class App {
private static final Logger logger = LoggerFactory.getLogger(App.class);
public static void main(final String[] args) {
final String postgresUser = System.getenv("POSTGRES_USER");
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
final String postgresUrl = System.getenv("POSTGRES_URL");
public static void main(String[] args) {
String postgresUser = System.getenv("POSTGRES_USER");
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
String postgresUrl = System.getenv("POSTGRES_URL");
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
System.exit(1);
}
final String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
final int apiPort = Integer.parseInt(apiPortStr);
final boolean debug = "true".equals(System.getenv("API_DEBUG"));
String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
int apiPort = Integer.parseInt(apiPortStr);
boolean debug = "true".equals(System.getenv("API_DEBUG"));
final var config = new HikariConfig();
var config = new HikariConfig();
config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser);
config.setPassword(postgresPassword);
try (final var ds = new HikariDataSource(config)) {
final var serverDBService = new PostgresServerDBService(ds);
final var userDBService = new PostgresUserDBService(ds);
final var appConfig = new AppConfig(apiPort, debug);
try (var ds = new HikariDataSource(config)) {
var serverDBService = new ServerDBService(ds);
var userDBService = new UserDBService(ds);
var appConfig = new AppConfig(apiPort, debug);
final var webApi = new WebAPI(serverDBService, userDBService, appConfig);
var webApi = new WebAPI(serverDBService, userDBService, appConfig);
webApi.run();
}
}

View file

@ -6,6 +6,7 @@ import de.hhhammer.dchat.web.server.ConfigCrudHandler;
import de.hhhammer.dchat.web.user.ConfigUserCrudHandler;
import io.javalin.Javalin;
import io.javalin.http.HttpStatus;
import io.javalin.http.staticfiles.Location;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -15,13 +16,13 @@ import java.util.concurrent.ExecutionException;
import static io.javalin.apibuilder.ApiBuilder.crud;
import static io.javalin.apibuilder.ApiBuilder.path;
public final class WebAPI implements Runnable {
public class WebAPI implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(WebAPI.class);
private final ServerDBService serverDBService;
private final UserDBService userDBService;
private final AppConfig appConfig;
public WebAPI(final ServerDBService serverDBService, final UserDBService userDBService, final AppConfig appConfig) {
public WebAPI(ServerDBService serverDBService, UserDBService userDBService, AppConfig appConfig) {
this.serverDBService = serverDBService;
this.userDBService = userDBService;
this.appConfig = appConfig;
@ -30,12 +31,12 @@ public final class WebAPI implements Runnable {
@Override
public void run() {
logger.info("Starting web application");
final Javalin app = Javalin.create(config -> {
var app = Javalin.create(config -> {
if (appConfig.debug()) config.plugins.enableDevLogging();
config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method
config.http.defaultContentType = "application/json";
});
final var waitForShutdown = new CompletableFuture<Void>();
var waitForShutdown = new CompletableFuture<Void>();
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
logger.info("Shutting down web application");
app.stop();

View file

@ -10,21 +10,18 @@ import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Optional;
public final class ConfigCrudHandler implements CrudHandler {
public class ConfigCrudHandler implements CrudHandler {
private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class);
private final ServerDBService serverDBService;
public ConfigCrudHandler(final ServerDBService serverDBService) {
public ConfigCrudHandler(ServerDBService serverDBService) {
this.serverDBService = serverDBService;
}
@Override
public void create(@NotNull final Context context) {
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
public void create(@NotNull Context context) {
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
try {
this.serverDBService.addConfig(body);
} catch (DBException e) {
@ -35,10 +32,9 @@ public final class ConfigCrudHandler implements CrudHandler {
}
@Override
public void delete(@NotNull final Context context, @NotNull final String s) {
final var id = Long.parseLong(s);
public void delete(@NotNull Context context, @NotNull String s) {
try {
this.serverDBService.deleteConfig(id);
this.serverDBService.deleteConfig(Long.parseLong(s));
context.status(HttpStatus.NO_CONTENT);
} catch (DBException e) {
logger.error("Deleting configuration with id: " + s, e);
@ -47,9 +43,9 @@ public final class ConfigCrudHandler implements CrudHandler {
}
@Override
public void getAll(@NotNull final Context context) {
public void getAll(@NotNull Context context) {
try {
final List<ServerConfig> allowedServers = this.serverDBService.getAllConfigs();
var allowedServers = this.serverDBService.getAllConfigs();
context.json(allowedServers);
} catch (DBException e) {
logger.error("Getting all server configs", e);
@ -58,10 +54,10 @@ public final class ConfigCrudHandler implements CrudHandler {
}
@Override
public void getOne(@NotNull final Context context, @NotNull final String s) {
final var id = Long.parseLong(s);
public void getOne(@NotNull Context context, @NotNull String s) {
var id = Long.parseLong(s);
try {
final Optional<ServerConfig> server = this.serverDBService.getConfigBy(id);
var server = this.serverDBService.getConfigBy(id);
if (server.isEmpty()) {
context.status(HttpStatus.NOT_FOUND);
return;
@ -74,9 +70,9 @@ public final class ConfigCrudHandler implements CrudHandler {
}
@Override
public void update(@NotNull final Context context, @NotNull final String idString) {
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
final var id = Long.parseLong(idString);
public void update(@NotNull Context context, @NotNull String idString) {
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
var id = Long.parseLong(idString);
try {
this.serverDBService.updateConfig(id, body);

View file

@ -10,21 +10,18 @@ import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Optional;
public final class ConfigUserCrudHandler implements CrudHandler {
public class ConfigUserCrudHandler implements CrudHandler {
private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class);
private final UserDBService userDBService;
public ConfigUserCrudHandler(final UserDBService userDBService) {
public ConfigUserCrudHandler(UserDBService userDBService) {
this.userDBService = userDBService;
}
@Override
public void create(@NotNull final Context context) {
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
public void create(@NotNull Context context) {
var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
try {
this.userDBService.addConfig(body);
} catch (DBException e) {
@ -36,10 +33,9 @@ public final class ConfigUserCrudHandler implements CrudHandler {
}
@Override
public void delete(@NotNull final Context context, @NotNull final String s) {
final var id = Long.parseLong(s);
public void delete(@NotNull Context context, @NotNull String s) {
try {
this.userDBService.deleteConfig(id);
this.userDBService.deleteConfig(Long.parseLong(s));
context.status(HttpStatus.NO_CONTENT);
} catch (DBException e) {
logger.error("Deleting configuration with id: " + s, e);
@ -48,9 +44,9 @@ public final class ConfigUserCrudHandler implements CrudHandler {
}
@Override
public void getAll(@NotNull final Context context) {
public void getAll(@NotNull Context context) {
try {
final List<UserConfig> allowedServers = this.userDBService.getAllConfigs();
var allowedServers = this.userDBService.getAllConfigs();
context.json(allowedServers);
} catch (DBException e) {
logger.error("Getting all user configs", e);
@ -59,10 +55,10 @@ public final class ConfigUserCrudHandler implements CrudHandler {
}
@Override
public void getOne(@NotNull final Context context, @NotNull final String s) {
final var id = Long.parseLong(s);
public void getOne(@NotNull Context context, @NotNull String s) {
var id = Long.parseLong(s);
try {
final Optional<UserConfig> server = this.userDBService.getConfigBy(id);
var server = this.userDBService.getConfigBy(id);
if (server.isEmpty()) {
context.status(HttpStatus.NOT_FOUND);
return;
@ -76,8 +72,8 @@ public final class ConfigUserCrudHandler implements CrudHandler {
@Override
public void update(@NotNull Context context, @NotNull String idString) {
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
final var id = Long.parseLong(idString);
var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
var id = Long.parseLong(idString);
try {
this.userDBService.updateConfig(id, body);