Compare commits
No commits in common. "168a5d6c818327cbf1d08964acaee4d5646cf9c6" and "5c67a47806c298746576d7b9f3e9fdadd39c335e" have entirely different histories.
168a5d6c81
...
5c67a47806
24 changed files with 562 additions and 642 deletions
|
@ -4,47 +4,47 @@ import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.zaxxer.hikari.HikariConfig;
|
import com.zaxxer.hikari.HikariConfig;
|
||||||
import com.zaxxer.hikari.HikariDataSource;
|
import com.zaxxer.hikari.HikariDataSource;
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||||
import de.hhhammer.dchat.db.PostgresServerDBService;
|
import de.hhhammer.dchat.db.ServerDBService;
|
||||||
import de.hhhammer.dchat.db.PostgresUserDBService;
|
import de.hhhammer.dchat.db.UserDBService;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.net.http.HttpClient;
|
import java.net.http.HttpClient;
|
||||||
|
|
||||||
public final class App {
|
public class App {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
||||||
|
|
||||||
public static void main(final String[] args) {
|
public static void main(String[] args) {
|
||||||
final String discordApiKey = System.getenv("DISCORD_API_KEY");
|
String discordApiKey = System.getenv("DISCORD_API_KEY");
|
||||||
if (discordApiKey == null) {
|
if (discordApiKey == null) {
|
||||||
logger.error("Missing environment variables: DISCORD_API_KEY");
|
logger.error("Missing environment variables: DISCORD_API_KEY");
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
final String openaiApiKey = System.getenv("OPENAI_API_KEY");
|
String openaiApiKey = System.getenv("OPENAI_API_KEY");
|
||||||
if (openaiApiKey == null) {
|
if (openaiApiKey == null) {
|
||||||
logger.error("Missing environment variables: OPENAI_API_KEY");
|
logger.error("Missing environment variables: OPENAI_API_KEY");
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
final String postgresUser = System.getenv("POSTGRES_USER");
|
String postgresUser = System.getenv("POSTGRES_USER");
|
||||||
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||||
final String postgresUrl = System.getenv("POSTGRES_URL");
|
String postgresUrl = System.getenv("POSTGRES_URL");
|
||||||
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
||||||
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
final var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
|
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient(), new ObjectMapper());
|
||||||
|
|
||||||
final var config = new HikariConfig();
|
var config = new HikariConfig();
|
||||||
config.setJdbcUrl(postgresUrl);
|
config.setJdbcUrl(postgresUrl);
|
||||||
config.setUsername(postgresUser);
|
config.setUsername(postgresUser);
|
||||||
config.setPassword(postgresPassword);
|
config.setPassword(postgresPassword);
|
||||||
|
|
||||||
try (var ds = new HikariDataSource(config)) {
|
try (var ds = new HikariDataSource(config)) {
|
||||||
final var serverDBService = new PostgresServerDBService(ds);
|
var serverDBService = new ServerDBService(ds);
|
||||||
final var userDBService = new PostgresUserDBService(ds);
|
var userDBService = new UserDBService(ds);
|
||||||
|
|
||||||
final var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
|
var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
|
||||||
discordBot.run();
|
discordBot.run();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import de.hhhammer.dchat.db.UserDBService;
|
||||||
import org.javacord.api.DiscordApi;
|
import org.javacord.api.DiscordApi;
|
||||||
import org.javacord.api.DiscordApiBuilder;
|
import org.javacord.api.DiscordApiBuilder;
|
||||||
import org.javacord.api.interaction.SlashCommand;
|
import org.javacord.api.interaction.SlashCommand;
|
||||||
import org.javacord.api.interaction.SlashCommandInteraction;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
@ -17,7 +16,7 @@ import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
public final class DiscordBot implements Runnable {
|
public class DiscordBot implements Runnable {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class);
|
private static final Logger logger = LoggerFactory.getLogger(DiscordBot.class);
|
||||||
|
|
||||||
private final ServerDBService serverDBService;
|
private final ServerDBService serverDBService;
|
||||||
|
@ -25,7 +24,7 @@ public final class DiscordBot implements Runnable {
|
||||||
private final ChatGPTService chatGPTService;
|
private final ChatGPTService chatGPTService;
|
||||||
private final String discordApiKey;
|
private final String discordApiKey;
|
||||||
|
|
||||||
public DiscordBot(final ServerDBService serverDBService, final UserDBService userDBService, final ChatGPTService chatGPTService, final String discordApiKey) {
|
public DiscordBot(ServerDBService serverDBService, UserDBService userDBService, ChatGPTService chatGPTService, String discordApiKey) {
|
||||||
this.serverDBService = serverDBService;
|
this.serverDBService = serverDBService;
|
||||||
this.userDBService = userDBService;
|
this.userDBService = userDBService;
|
||||||
this.chatGPTService = chatGPTService;
|
this.chatGPTService = chatGPTService;
|
||||||
|
@ -35,29 +34,29 @@ public final class DiscordBot implements Runnable {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
logger.info("Starting Discord application");
|
logger.info("Starting Discord application");
|
||||||
final DiscordApi discordApi = new DiscordApiBuilder()
|
DiscordApi discordApi = new DiscordApiBuilder()
|
||||||
.setToken(discordApiKey)
|
.setToken(discordApiKey)
|
||||||
.login()
|
.login()
|
||||||
.join();
|
.join();
|
||||||
discordApi.setMessageCacheSize(10, 60*60);
|
discordApi.setMessageCacheSize(10, 60*60);
|
||||||
final var future = new CompletableFuture<Void>();
|
var future = new CompletableFuture<Void>();
|
||||||
Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> {
|
Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> {
|
||||||
logger.info("Shutting down Discord application");
|
logger.info("Shutting down Discord application");
|
||||||
discordApi.disconnect().thenAccept(future::complete);
|
discordApi.disconnect().thenAccept(future::complete);
|
||||||
}));
|
}));
|
||||||
final SlashCommand token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
|
var token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
|
||||||
.createGlobal(discordApi)
|
.createGlobal(discordApi)
|
||||||
.join();
|
.join();
|
||||||
|
|
||||||
discordApi.addSlashCommandCreateListener(event -> {
|
discordApi.addSlashCommandCreateListener(event -> {
|
||||||
logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName());
|
logger.debug("Event? " + event.getSlashCommandInteraction().getFullCommandName());
|
||||||
final SlashCommandInteraction command = event.getSlashCommandInteraction();
|
var command = event.getSlashCommandInteraction();
|
||||||
if (token.getFullCommandNames().contains(command.getFullCommandName())) {
|
if (token.getFullCommandNames().contains(command.getFullCommandName())) {
|
||||||
event.getInteraction()
|
event.getInteraction()
|
||||||
.respondLater()
|
.respondLater()
|
||||||
.orTimeout(30, TimeUnit.SECONDS)
|
.orTimeout(30, TimeUnit.SECONDS)
|
||||||
.thenAccept((interactionOriginalResponseUpdater) -> {
|
.thenAccept((interactionOriginalResponseUpdater) -> {
|
||||||
final long tokens = event.getInteraction().getServer().isPresent() ?
|
var tokens = event.getInteraction().getServer().isPresent() ?
|
||||||
this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) :
|
this.serverDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getServer().get().getId())) :
|
||||||
this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId()));
|
this.userDBService.tokensOfLast30Days(String.valueOf(event.getInteraction().getUser().getId()));
|
||||||
interactionOriginalResponseUpdater.setContent("" + tokens).update();
|
interactionOriginalResponseUpdater.setContent("" + tokens).update();
|
||||||
|
|
|
@ -9,17 +9,17 @@ import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
public final class MessageCreateHandler implements MessageCreateListener {
|
public class MessageCreateHandler implements MessageCreateListener {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(MessageCreateHandler.class);
|
||||||
|
|
||||||
private final MessageHandler messageHandler;
|
private final MessageHandler messageHandler;
|
||||||
|
|
||||||
public MessageCreateHandler(final MessageHandler messageHandler) {
|
public MessageCreateHandler(MessageHandler messageHandler) {
|
||||||
this.messageHandler = messageHandler;
|
this.messageHandler = messageHandler;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onMessageCreate(final MessageCreateEvent event) {
|
public void onMessageCreate(MessageCreateEvent event) {
|
||||||
Thread.ofVirtual().start(() -> {
|
Thread.ofVirtual().start(() -> {
|
||||||
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
|
if (!event.canYouReadContent() || event.getMessageAuthor().isBotUser() || !isNormalOrReplyMessageType(event)) {
|
||||||
return;
|
return;
|
||||||
|
@ -36,15 +36,15 @@ public final class MessageCreateHandler implements MessageCreateListener {
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
this.messageHandler.handle(event);
|
this.messageHandler.handle(event);
|
||||||
} catch (final ResponseException | IOException | InterruptedException e) {
|
} catch (ResponseException | IOException | InterruptedException e) {
|
||||||
logger.error("Reading a message from the listener", e);
|
logger.error("Reading a message from the listener", e);
|
||||||
event.getMessage().reply("Sorry but something went wrong :(");
|
event.getMessage().reply("Sorry but something went wrong :(");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean isNormalOrReplyMessageType(final MessageCreateEvent event) {
|
private boolean isNormalOrReplyMessageType(MessageCreateEvent event) {
|
||||||
final MessageType type = event.getMessage().getType();
|
MessageType type = event.getMessage().getType();
|
||||||
return type == MessageType.NORMAL || type == MessageType.REPLY;
|
return type == MessageType.NORMAL || type == MessageType.REPLY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,10 +4,7 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||||
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
|
import de.hhhammer.dchat.bot.openai.MessageContext.ReplyInteraction;
|
||||||
import de.hhhammer.dchat.bot.openai.ResponseException;
|
import de.hhhammer.dchat.bot.openai.ResponseException;
|
||||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
|
||||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
|
|
||||||
import de.hhhammer.dchat.db.ServerDBService;
|
import de.hhhammer.dchat.db.ServerDBService;
|
||||||
import de.hhhammer.dchat.db.models.server.ServerConfig;
|
|
||||||
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
||||||
import org.javacord.api.entity.DiscordEntity;
|
import org.javacord.api.entity.DiscordEntity;
|
||||||
import org.javacord.api.entity.message.Message;
|
import org.javacord.api.entity.message.Message;
|
||||||
|
@ -20,44 +17,43 @@ import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
public final class ServerMessageHandler implements MessageHandler {
|
public class ServerMessageHandler implements MessageHandler {
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(ServerMessageHandler.class);
|
||||||
private final ServerDBService serverDBService;
|
private final ServerDBService serverDBService;
|
||||||
private final ChatGPTService chatGPTService;
|
private final ChatGPTService chatGPTService;
|
||||||
|
|
||||||
public ServerMessageHandler(final ServerDBService serverDBService, final ChatGPTService chatGPTService) {
|
public ServerMessageHandler(ServerDBService serverDBService, ChatGPTService chatGPTService) {
|
||||||
this.serverDBService = serverDBService;
|
this.serverDBService = serverDBService;
|
||||||
this.chatGPTService = chatGPTService;
|
this.chatGPTService = chatGPTService;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
||||||
final String content = extractContent(event);
|
String content = extractContent(event);
|
||||||
final long serverId = event.getServer().get().getId();
|
var serverId = event.getServer().get().getId();
|
||||||
final String systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
|
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
|
||||||
final List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
|
List<ReplyInteraction> messageContext = event.getMessage().getType() == MessageType.REPLY ? getContextMessages(event) : List.of();
|
||||||
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
|
var request = new ChatGPTRequestBuilder().contextRequest(messageContext, content, systemMessage);
|
||||||
final ChatGPTResponse response = this.chatGPTService.submit(request);
|
var response = this.chatGPTService.submit(request);
|
||||||
if (response.choices().isEmpty()) {
|
if (response.choices().isEmpty()) {
|
||||||
event.getMessage().reply("No response available");
|
event.getMessage().reply("No response available");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
final String answer = response.choices().get(0).message().content();
|
var answer = response.choices().get(0).message().content();
|
||||||
logServerMessage(event, response.usage().totalTokens());
|
logServerMessage(event, response.usage().totalTokens());
|
||||||
event.getMessage().reply(answer);
|
event.getMessage().reply(answer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isAllowed(final MessageCreateEvent event) {
|
public boolean isAllowed(MessageCreateEvent event) {
|
||||||
if (event.getServer().isEmpty()) {
|
if (event.getServer().isEmpty()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
final long serverId = event.getServer().get().getId();
|
var serverId = event.getServer().get().getId();
|
||||||
final Optional<ServerConfig> config = this.serverDBService.getConfig(String.valueOf(serverId));
|
var config = this.serverDBService.getConfig(String.valueOf(serverId));
|
||||||
if (config.isEmpty()) {
|
if (config.isEmpty()) {
|
||||||
logger.debug("Not allowed with id: " + serverId);
|
logger.debug("Not allowed with id: " + serverId);
|
||||||
return false;
|
return false;
|
||||||
|
@ -66,39 +62,39 @@ public final class ServerMessageHandler implements MessageHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean exceedsRate(final MessageCreateEvent event) {
|
public boolean exceedsRate(MessageCreateEvent event) {
|
||||||
final String serverId = String.valueOf(event.getServer().get().getId());
|
var serverId = String.valueOf(event.getServer().get().getId());
|
||||||
final Optional<ServerConfig> config = this.serverDBService.getConfig(serverId);
|
var config = this.serverDBService.getConfig(serverId);
|
||||||
if (config.isEmpty()) {
|
if (config.isEmpty()) {
|
||||||
logger.error("Missing configuration for server with id: " + serverId);
|
logger.error("Missing configuration for server with id: " + serverId);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
final int rateLimit = config.get().rateLimit();
|
var rateLimit = config.get().rateLimit();
|
||||||
final int countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
|
var countMessagesInLastMinute = this.serverDBService.countMessagesInLastMinute(serverId);
|
||||||
|
|
||||||
return countMessagesInLastMinute >= rateLimit;
|
return countMessagesInLastMinute >= rateLimit;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean canHandle(final MessageCreateEvent event) {
|
public boolean canHandle(MessageCreateEvent event) {
|
||||||
return event.isServerMessage();
|
return event.isServerMessage();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void logServerMessage(final MessageCreateEvent event, final int tokens) {
|
private void logServerMessage(MessageCreateEvent event, int tokens) {
|
||||||
final long serverId = event.getServer().map(DiscordEntity::getId).get();
|
var serverId = event.getServer().map(DiscordEntity::getId).get();
|
||||||
final long userId = event.getMessageAuthor().getId();
|
var userId = event.getMessageAuthor().getId();
|
||||||
|
|
||||||
final var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
|
var serverMessage = new ServerMessage.NewServerMessage(String.valueOf(serverId), userId, tokens);
|
||||||
this.serverDBService.addMessage(serverMessage);
|
this.serverDBService.addMessage(serverMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String extractContent(final MessageCreateEvent event) {
|
private String extractContent(MessageCreateEvent event) {
|
||||||
final long ownId = event.getApi().getYourself().getId();
|
long ownId = event.getApi().getYourself().getId();
|
||||||
return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
|
return event.getMessageContent().replaceFirst("<" + ownId + "> ", "");
|
||||||
}
|
}
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
private List<ReplyInteraction> getContextMessages(final MessageCreateEvent event) {
|
private List<ReplyInteraction> getContextMessages(MessageCreateEvent event) {
|
||||||
return event.getMessage()
|
return event.getMessage()
|
||||||
.getMessageReference()
|
.getMessageReference()
|
||||||
.map(MessageReference::getMessage)
|
.map(MessageReference::getMessage)
|
||||||
|
|
|
@ -4,10 +4,7 @@ import de.hhhammer.dchat.bot.openai.ChatGPTRequestBuilder;
|
||||||
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
import de.hhhammer.dchat.bot.openai.ChatGPTService;
|
||||||
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
|
import de.hhhammer.dchat.bot.openai.MessageContext.PreviousInteraction;
|
||||||
import de.hhhammer.dchat.bot.openai.ResponseException;
|
import de.hhhammer.dchat.bot.openai.ResponseException;
|
||||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
|
||||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
|
|
||||||
import de.hhhammer.dchat.db.UserDBService;
|
import de.hhhammer.dchat.db.UserDBService;
|
||||||
import de.hhhammer.dchat.db.models.user.UserConfig;
|
|
||||||
import de.hhhammer.dchat.db.models.user.UserMessage;
|
import de.hhhammer.dchat.db.models.user.UserMessage;
|
||||||
import org.javacord.api.event.message.MessageCreateEvent;
|
import org.javacord.api.event.message.MessageCreateEvent;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
|
@ -15,47 +12,46 @@ import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
public final class UserMessageHandler implements MessageHandler {
|
public class UserMessageHandler implements MessageHandler {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(UserMessageHandler.class);
|
||||||
private final UserDBService userDBService;
|
private final UserDBService userDBService;
|
||||||
private final ChatGPTService chatGPTService;
|
private final ChatGPTService chatGPTService;
|
||||||
|
|
||||||
public UserMessageHandler(final UserDBService userDBService, final ChatGPTService chatGPTService) {
|
public UserMessageHandler(UserDBService userDBService, ChatGPTService chatGPTService) {
|
||||||
this.userDBService = userDBService;
|
this.userDBService = userDBService;
|
||||||
this.chatGPTService = chatGPTService;
|
this.chatGPTService = chatGPTService;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handle(final MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
|
||||||
final String content = event.getReadableMessageContent();
|
String content = event.getReadableMessageContent();
|
||||||
final String userId = String.valueOf(event.getMessageAuthor().getId());
|
var userId = String.valueOf(event.getMessageAuthor().getId());
|
||||||
final UserConfig config = this.userDBService.getConfig(userId).get();
|
var config = this.userDBService.getConfig(userId).get();
|
||||||
final String systemMessage = config.systemMessage();
|
var systemMessage = config.systemMessage();
|
||||||
final List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
|
List<PreviousInteraction> context = this.userDBService.getLastMessages(userId, config.contextLength())
|
||||||
.stream()
|
.stream()
|
||||||
.map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
|
.map(userMessage -> new PreviousInteraction(userMessage.question(), userMessage.answer()))
|
||||||
.toList();
|
.toList();
|
||||||
final ChatGPTRequest request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
|
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
|
||||||
final ChatGPTResponse response = this.chatGPTService.submit(request);
|
var response = this.chatGPTService.submit(request);
|
||||||
if (response.choices().isEmpty()) {
|
if (response.choices().isEmpty()) {
|
||||||
event.getMessage().reply("No response available");
|
event.getMessage().reply("No response available");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
final String answer = response.choices().get(0).message().content();
|
var answer = response.choices().get(0).message().content();
|
||||||
logUserMessage(event, content, answer, response.usage().totalTokens());
|
logUserMessage(event, content, answer, response.usage().totalTokens());
|
||||||
event.getChannel().sendMessage(answer);
|
event.getChannel().sendMessage(answer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isAllowed(final MessageCreateEvent event) {
|
public boolean isAllowed(MessageCreateEvent event) {
|
||||||
if (event.getServer().isPresent()) {
|
if (event.getServer().isPresent()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
final long userId = event.getMessageAuthor().getId();
|
var userId = event.getMessageAuthor().getId();
|
||||||
final Optional<UserConfig> config = this.userDBService.getConfig(String.valueOf(userId));
|
var config = this.userDBService.getConfig(String.valueOf(userId));
|
||||||
if (config.isEmpty()) {
|
if (config.isEmpty()) {
|
||||||
logger.debug("Not allowed with id: " + userId);
|
logger.debug("Not allowed with id: " + userId);
|
||||||
return false;
|
return false;
|
||||||
|
@ -64,28 +60,28 @@ public final class UserMessageHandler implements MessageHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean exceedsRate(final MessageCreateEvent event) {
|
public boolean exceedsRate(MessageCreateEvent event) {
|
||||||
final String userId = String.valueOf(event.getMessageAuthor().getId());
|
var userId = String.valueOf(event.getMessageAuthor().getId());
|
||||||
final Optional<UserConfig> config = this.userDBService.getConfig(userId);
|
var config = this.userDBService.getConfig(userId);
|
||||||
if (config.isEmpty()) {
|
if (config.isEmpty()) {
|
||||||
logger.error("Missing configuration for userId with id: " + userId);
|
logger.error("Missing configuration for userId with id: " + userId);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
final int rateLimit = config.get().rateLimit();
|
var rateLimit = config.get().rateLimit();
|
||||||
final int countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
|
var countMessagesInLastMinute = this.userDBService.countMessagesInLastMinute(userId);
|
||||||
|
|
||||||
return countMessagesInLastMinute >= rateLimit;
|
return countMessagesInLastMinute >= rateLimit;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean canHandle(final MessageCreateEvent event) {
|
public boolean canHandle(MessageCreateEvent event) {
|
||||||
return event.isPrivateMessage();
|
return event.isPrivateMessage();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void logUserMessage(final MessageCreateEvent event, final String question, final String answer, final int tokens) {
|
private void logUserMessage(MessageCreateEvent event, String question, String answer, int tokens) {
|
||||||
final long userId = event.getMessageAuthor().getId();
|
var userId = event.getMessageAuthor().getId();
|
||||||
|
|
||||||
final UserMessage.NewUserMessage userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
|
var userMessage = new UserMessage.NewUserMessage(String.valueOf(userId), question, answer, tokens);
|
||||||
this.userDBService.addMessage(userMessage);
|
this.userDBService.addMessage(userMessage);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,19 +8,29 @@ import org.jetbrains.annotations.NotNull;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public final class ChatGPTRequestBuilder {
|
public class ChatGPTRequestBuilder {
|
||||||
private static final String smallContextModel = "gpt-3.5-turbo";
|
private static final String smallContextModel = "gpt-3.5-turbo";
|
||||||
private static final String bigContextModel = "gpt-3.5-turbo-16k";
|
private static final String bigContextModel = "gpt-3.5-turbo-16k";
|
||||||
|
|
||||||
public ChatGPTRequestBuilder() {
|
public ChatGPTRequestBuilder() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ChatGPTRequest contextRequest(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
|
||||||
|
List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
|
||||||
|
String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
|
||||||
|
return new ChatGPTRequest(
|
||||||
|
contextModel,
|
||||||
|
messages,
|
||||||
|
0.7f
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
private static List<ChatGPTRequest.Message> getMessages(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
|
private static List<ChatGPTRequest.Message> getMessages(List<? extends MessageContext> contextMessages, String message, String systemMessage) {
|
||||||
final ChatGPTRequest.Message systemMsg = new ChatGPTRequest.Message("system", systemMessage);
|
var systemMsg = new ChatGPTRequest.Message("system", systemMessage);
|
||||||
final List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
|
List<ChatGPTRequest.Message> contextMsgs = getContextMessages(contextMessages);
|
||||||
final ChatGPTRequest.Message userMessage = new ChatGPTRequest.Message("user", message);
|
var userMessage = new ChatGPTRequest.Message("user", message);
|
||||||
final List<ChatGPTRequest.Message> messages = new ArrayList<>();
|
List<ChatGPTRequest.Message> messages = new ArrayList<>();
|
||||||
messages.add(systemMsg);
|
messages.add(systemMsg);
|
||||||
messages.addAll(contextMsgs);
|
messages.addAll(contextMsgs);
|
||||||
messages.add(userMessage);
|
messages.add(userMessage);
|
||||||
|
@ -28,7 +38,7 @@ public final class ChatGPTRequestBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
private static List<ChatGPTRequest.Message> getContextMessages(final List<? extends MessageContext> contextMessages) {
|
private static List<ChatGPTRequest.Message> getContextMessages(List<? extends MessageContext> contextMessages) {
|
||||||
return contextMessages.stream()
|
return contextMessages.stream()
|
||||||
.map(ChatGPTRequestBuilder::mapContextMessages)
|
.map(ChatGPTRequestBuilder::mapContextMessages)
|
||||||
.flatMap(List::stream)
|
.flatMap(List::stream)
|
||||||
|
@ -36,7 +46,7 @@ public final class ChatGPTRequestBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
private static List<ChatGPTRequest.Message> mapContextMessages(final MessageContext contextMessage) {
|
private static List<ChatGPTRequest.Message> mapContextMessages(MessageContext contextMessage) {
|
||||||
return switch (contextMessage) {
|
return switch (contextMessage) {
|
||||||
case PreviousInteraction previousInteractions -> List.of(
|
case PreviousInteraction previousInteractions -> List.of(
|
||||||
new ChatGPTRequest.Message("user", previousInteractions.question()),
|
new ChatGPTRequest.Message("user", previousInteractions.question()),
|
||||||
|
@ -46,14 +56,4 @@ public final class ChatGPTRequestBuilder {
|
||||||
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
|
List.of(new ChatGPTRequest.Message("assistant", replyInteractions.answer()));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatGPTRequest contextRequest(final List<? extends MessageContext> contextMessages, final String message, final String systemMessage) {
|
|
||||||
final List<ChatGPTRequest.Message> messages = getMessages(contextMessages, message, systemMessage);
|
|
||||||
final String contextModel = contextMessages.size() <= 1 ? smallContextModel : bigContextModel;
|
|
||||||
return new ChatGPTRequest(
|
|
||||||
contextModel,
|
|
||||||
messages,
|
|
||||||
0.7f
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,35 +5,34 @@ import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
|
||||||
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
|
import de.hhhammer.dchat.bot.openai.models.ChatGPTResponse;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.http.HttpClient;
|
import java.net.http.HttpClient;
|
||||||
import java.net.http.HttpRequest;
|
import java.net.http.HttpRequest;
|
||||||
import java.net.http.HttpResponse;
|
import java.net.http.HttpResponse;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
|
||||||
public final class ChatGPTService {
|
public class ChatGPTService {
|
||||||
private static final String url = "https://api.openai.com/v1/chat/completions";
|
private static final String url = "https://api.openai.com/v1/chat/completions";
|
||||||
private final String apiKey;
|
private final String apiKey;
|
||||||
private final HttpClient httpClient;
|
private final HttpClient httpClient;
|
||||||
private final ObjectMapper mapper;
|
private final ObjectMapper mapper;
|
||||||
|
|
||||||
public ChatGPTService(final String apiKey, final HttpClient httpClient, final ObjectMapper mapper) {
|
public ChatGPTService(String apiKey, HttpClient httpClient, ObjectMapper mapper) {
|
||||||
this.apiKey = apiKey;
|
this.apiKey = apiKey;
|
||||||
this.httpClient = httpClient;
|
this.httpClient = httpClient;
|
||||||
this.mapper = mapper;
|
this.mapper = mapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatGPTResponse submit(final ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
|
public ChatGPTResponse submit(ChatGPTRequest chatGPTRequest) throws IOException, InterruptedException, ResponseException {
|
||||||
final byte[] data = mapper.writeValueAsBytes(chatGPTRequest);
|
var data = mapper.writeValueAsBytes(chatGPTRequest);
|
||||||
final HttpRequest request = HttpRequest.newBuilder(URI.create(url))
|
var request = HttpRequest.newBuilder(URI.create(url))
|
||||||
.POST(HttpRequest.BodyPublishers.ofByteArray(data))
|
.POST(HttpRequest.BodyPublishers.ofByteArray(data))
|
||||||
.setHeader("Content-Type", "application/json")
|
.setHeader("Content-Type", "application/json")
|
||||||
.setHeader("Authorization", "Bearer " + this.apiKey)
|
.setHeader("Authorization", "Bearer " + this.apiKey)
|
||||||
.timeout(Duration.ofMinutes(5))
|
.timeout(Duration.ofMinutes(5))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final HttpResponse<InputStream> responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
var responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
if (responseStream.statusCode() != 200) {
|
if (responseStream.statusCode() != 200) {
|
||||||
throw new ResponseException("Response status code was not 200: " + responseStream.statusCode());
|
throw new ResponseException("Response status code was not 200: " + responseStream.statusCode());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package de.hhhammer.dchat.bot.openai;
|
package de.hhhammer.dchat.bot.openai;
|
||||||
|
|
||||||
public final class ResponseException extends Exception {
|
public class ResponseException extends Exception {
|
||||||
public ResponseException(final String message) {
|
public ResponseException(String message) {
|
||||||
super(message);
|
super(message);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package de.hhhammer.dchat.db;
|
package de.hhhammer.dchat.db;
|
||||||
|
|
||||||
public final class DBException extends Exception {
|
public class DBException extends Exception {
|
||||||
public DBException(final String message, final Throwable cause) {
|
public DBException(String message, Throwable cause) {
|
||||||
super(message, cause);
|
super(message, cause);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,202 +0,0 @@
|
||||||
package de.hhhammer.dchat.db;
|
|
||||||
|
|
||||||
import de.hhhammer.dchat.db.models.server.ServerConfig;
|
|
||||||
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import javax.sql.DataSource;
|
|
||||||
import java.sql.*;
|
|
||||||
import java.time.Instant;
|
|
||||||
import java.time.temporal.ChronoUnit;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.stream.StreamSupport;
|
|
||||||
|
|
||||||
public final class PostgresServerDBService implements ServerDBService {
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(PostgresServerDBService.class);
|
|
||||||
private final DataSource dataSource;
|
|
||||||
|
|
||||||
public PostgresServerDBService(final DataSource dataSource) {
|
|
||||||
this.dataSource = dataSource;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Optional<ServerConfig> getConfig(final String serverId) {
|
|
||||||
final var getServerConfig = """
|
|
||||||
SELECT * FROM server_configs WHERE server_id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, serverId);
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
|
||||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Getting configuration for server with id: " + serverId, e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
|
|
||||||
}
|
|
||||||
return Optional.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ServerConfig> getAllConfigs() throws DBException {
|
|
||||||
final var getAllowedServerSql = """
|
|
||||||
SELECT * FROM server_configs
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
|
|
||||||
) {
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
|
||||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Loading all configs", e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
throw new DBException("Iterating over configs", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Optional<ServerConfig> getConfigBy(final long id) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
SELECT * FROM server_configs WHERE id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setLong(1, id);
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
final Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
|
||||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
|
||||||
} catch (SQLException e) {
|
|
||||||
throw new DBException("Getting configuration with id: " + id, e);
|
|
||||||
} catch (ResultSetIteratorException e) {
|
|
||||||
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addConfig(final ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, newServerConfig.serverId());
|
|
||||||
pstmt.setString(2, newServerConfig.systemMessage());
|
|
||||||
pstmt.setInt(3, newServerConfig.rateLimit());
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No config added for server with id: " + newServerConfig.serverId());
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void updateConfig(final long id, final ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, newServerConfig.systemMessage());
|
|
||||||
pstmt.setInt(2, newServerConfig.rateLimit());
|
|
||||||
pstmt.setString(3, newServerConfig.serverId());
|
|
||||||
pstmt.setLong(4, id);
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No config update for server with id: " + newServerConfig.serverId());
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void deleteConfig(final long id) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
DELETE FROM server_configs WHERE id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setLong(1, id);
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No config deleted for server with id: " + id);
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Deleting configuration for server with id: " + id, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int countMessagesInLastMinute(final String serverId) {
|
|
||||||
final var getServerConfig = """
|
|
||||||
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, serverId);
|
|
||||||
final var now = Instant.now();
|
|
||||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
|
||||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
if (resultSet.next()) return resultSet.getInt(1);
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Getting messages for server with id: " + serverId, e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
|
|
||||||
}
|
|
||||||
return Integer.MAX_VALUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addMessage(final ServerMessage.NewServerMessage serverMessage) {
|
|
||||||
final var getServerConfig = """
|
|
||||||
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, serverMessage.serverId());
|
|
||||||
pstmt.setLong(2, serverMessage.userId());
|
|
||||||
pstmt.setInt(3, serverMessage.tokens());
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No message added for server with id: " + serverMessage.serverId());
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long tokensOfLast30Days(final String serverId) {
|
|
||||||
final var countTokensOfLast30Days = """
|
|
||||||
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, serverId);
|
|
||||||
final var now = Instant.now();
|
|
||||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
|
||||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
if (resultSet.next()) return resultSet.getLong(1);
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
|
|
||||||
}
|
|
||||||
logger.error("No tokens found for server with id: " + serverId);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,226 +0,0 @@
|
||||||
package de.hhhammer.dchat.db;
|
|
||||||
|
|
||||||
import de.hhhammer.dchat.db.models.user.UserConfig;
|
|
||||||
import de.hhhammer.dchat.db.models.user.UserMessage;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import javax.sql.DataSource;
|
|
||||||
import java.sql.*;
|
|
||||||
import java.time.Instant;
|
|
||||||
import java.time.temporal.ChronoUnit;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.stream.StreamSupport;
|
|
||||||
|
|
||||||
public final class PostgresUserDBService implements UserDBService {
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(PostgresUserDBService.class);
|
|
||||||
private final DataSource dataSource;
|
|
||||||
|
|
||||||
public PostgresUserDBService(final DataSource dataSource) {
|
|
||||||
this.dataSource = dataSource;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Optional<UserConfig> getConfig(final String userId) {
|
|
||||||
final var getServerConfig = """
|
|
||||||
SELECT * FROM user_configs WHERE user_id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, userId);
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
|
||||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Getting configuration for user with id: " + userId, e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
|
|
||||||
}
|
|
||||||
return Optional.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Optional<UserConfig> getConfigBy(final long id) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
SELECT * FROM user_configs WHERE id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setLong(1, id);
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
|
||||||
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Getting configuration id: " + id, e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<UserConfig> getAllConfigs() throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
SELECT * FROM user_configs
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
final Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
|
||||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Getting all configurations", e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
throw new DBException("Iterating over all UserConfig ResultSet", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, newUserConfig.userId());
|
|
||||||
pstmt.setString(2, newUserConfig.systemMessage());
|
|
||||||
pstmt.setInt(3, newUserConfig.contextLength());
|
|
||||||
pstmt.setInt(4, newUserConfig.rateLimit());
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No config added for user with id: " + newUserConfig.userId());
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void updateConfig(final long id, final UserConfig.NewUserConfig newUserConfig) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, newUserConfig.systemMessage());
|
|
||||||
pstmt.setInt(2, newUserConfig.rateLimit());
|
|
||||||
pstmt.setLong(3, newUserConfig.contextLength());
|
|
||||||
pstmt.setString(4, newUserConfig.userId());
|
|
||||||
pstmt.setLong(5, id);
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No config update with id: " + id);
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Updating configuration with id: " + id, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void deleteConfig(final long id) throws DBException {
|
|
||||||
final var getServerConfig = """
|
|
||||||
DELETE FROM user_configs WHERE id = ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setLong(1, id);
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No config deleted for user with id: " + id);
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
throw new DBException("Deleting configuration with id: " + id, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int countMessagesInLastMinute(final String userId) {
|
|
||||||
final var getServerConfig = """
|
|
||||||
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, userId);
|
|
||||||
final var now = Instant.now();
|
|
||||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
|
||||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
if (resultSet.next()) return resultSet.getInt(1);
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Getting messages for user with id: " + userId, e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
|
|
||||||
}
|
|
||||||
return Integer.MAX_VALUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addMessage(final UserMessage.NewUserMessage newUserMessage) {
|
|
||||||
final var getServerConfig = """
|
|
||||||
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, newUserMessage.userId());
|
|
||||||
pstmt.setString(2, newUserMessage.question());
|
|
||||||
pstmt.setString(3, newUserMessage.answer());
|
|
||||||
pstmt.setInt(4, newUserMessage.tokens());
|
|
||||||
final int affectedRows = pstmt.executeUpdate();
|
|
||||||
if (affectedRows == 0) {
|
|
||||||
logger.error("No message added for user with id: " + newUserMessage.userId());
|
|
||||||
}
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<UserMessage> getLastMessages(final String userId, final int limit) {
|
|
||||||
final var getLastMessages = """
|
|
||||||
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(getLastMessages)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, userId);
|
|
||||||
pstmt.setInt(2, limit);
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
final Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
|
|
||||||
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Fetching last messages for user whit id: " + userId, e);
|
|
||||||
} catch (final ResultSetIteratorException e) {
|
|
||||||
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
|
|
||||||
}
|
|
||||||
return List.of();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long tokensOfLast30Days(final String userId) {
|
|
||||||
final var countTokensOfLast30Days = """
|
|
||||||
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
|
|
||||||
""";
|
|
||||||
try (final Connection con = dataSource.getConnection();
|
|
||||||
final PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
|
||||||
) {
|
|
||||||
pstmt.setString(1, userId);
|
|
||||||
final var now = Instant.now();
|
|
||||||
pstmt.setTimestamp(2, Timestamp.from(now));
|
|
||||||
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
|
||||||
final ResultSet resultSet = pstmt.executeQuery();
|
|
||||||
if (resultSet.next()) return resultSet.getLong(1);
|
|
||||||
} catch (final SQLException e) {
|
|
||||||
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
|
|
||||||
}
|
|
||||||
logger.error("No tokens found for user with id: " + userId);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -4,11 +4,11 @@ import java.sql.ResultSet;
|
||||||
import java.sql.SQLException;
|
import java.sql.SQLException;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
|
||||||
public final class ResultSetIterator<T> implements Iterator<T> {
|
public class ResultSetIterator<T> implements Iterator<T> {
|
||||||
private final ResultSet resultSet;
|
private final ResultSet resultSet;
|
||||||
private final ResultSetTransformer<T> transformer;
|
private final ResultSetTransformer<T> transformer;
|
||||||
|
|
||||||
public ResultSetIterator(final ResultSet resultSet, final ResultSetTransformer<T> transformer) {
|
public ResultSetIterator(ResultSet resultSet, ResultSetTransformer<T> transformer) {
|
||||||
this.resultSet = resultSet;
|
this.resultSet = resultSet;
|
||||||
this.transformer = transformer;
|
this.transformer = transformer;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package de.hhhammer.dchat.db;
|
package de.hhhammer.dchat.db;
|
||||||
|
|
||||||
public final class ResultSetIteratorException extends RuntimeException {
|
public class ResultSetIteratorException extends RuntimeException {
|
||||||
public ResultSetIteratorException(final Throwable cause) {
|
public ResultSetIteratorException(Throwable cause) {
|
||||||
super(cause);
|
super(cause);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,26 +2,196 @@ package de.hhhammer.dchat.db;
|
||||||
|
|
||||||
import de.hhhammer.dchat.db.models.server.ServerConfig;
|
import de.hhhammer.dchat.db.models.server.ServerConfig;
|
||||||
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
import de.hhhammer.dchat.db.models.server.ServerMessage;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import javax.sql.DataSource;
|
||||||
|
import java.sql.*;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.time.temporal.ChronoUnit;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import java.util.stream.StreamSupport;
|
||||||
|
|
||||||
public interface ServerDBService {
|
public class ServerDBService {
|
||||||
Optional<ServerConfig> getConfig(String serverId);
|
private static final Logger logger = LoggerFactory.getLogger(ServerDBService.class);
|
||||||
|
private final DataSource dataSource;
|
||||||
|
|
||||||
List<ServerConfig> getAllConfigs() throws DBException;
|
public ServerDBService(DataSource dataSource) {
|
||||||
|
this.dataSource = dataSource;
|
||||||
Optional<ServerConfig> getConfigBy(long id) throws DBException;
|
}
|
||||||
|
|
||||||
void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException;
|
public Optional<ServerConfig> getConfig(String serverId) {
|
||||||
|
var getServerConfig = """
|
||||||
void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException;
|
SELECT * FROM server_configs WHERE server_id = ?
|
||||||
|
""";
|
||||||
void deleteConfig(long id) throws DBException;
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
int countMessagesInLastMinute(String serverId);
|
) {
|
||||||
|
pstmt.setString(1, serverId);
|
||||||
void addMessage(ServerMessage.NewServerMessage serverMessage);
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||||
long tokensOfLast30Days(String serverId);
|
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Getting configuration for server with id: " + serverId, e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
logger.error("Iterating over ServerConfig ResultSet for server with id: " + serverId, e);
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<ServerConfig> getAllConfigs() throws DBException {
|
||||||
|
var getAllowedServerSql = """
|
||||||
|
SELECT * FROM server_configs
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
|
||||||
|
) {
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||||
|
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Loading all configs", e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
throw new DBException("Iterating over configs", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Optional<ServerConfig> getConfigBy(long id) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
SELECT * FROM server_configs WHERE id = ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setLong(1, id);
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
Iterable<ServerConfig> iterable = () -> new ResultSetIterator<>(resultSet, new ServerConfig.ServerConfigResultSetTransformer());
|
||||||
|
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Getting configuration with id: " + id, e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
throw new DBException("Iterating over ServerConfig ResultSet for id: " + id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addConfig(ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, newServerConfig.serverId());
|
||||||
|
pstmt.setString(2, newServerConfig.systemMessage());
|
||||||
|
pstmt.setInt(3, newServerConfig.rateLimit());
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No config added for server with id: " + newServerConfig.serverId());
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateConfig(long id, ServerConfig.NewServerConfig newServerConfig) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, newServerConfig.systemMessage());
|
||||||
|
pstmt.setInt(2, newServerConfig.rateLimit());
|
||||||
|
pstmt.setString(3, newServerConfig.serverId());
|
||||||
|
pstmt.setLong(4, id);
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No config update for server with id: " + newServerConfig.serverId());
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Adding configuration to server with id: " + newServerConfig.serverId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void deleteConfig(long id) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
DELETE FROM server_configs WHERE id = ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setLong(1, id);
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No config deleted for server with id: " + id);
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Deleting configuration for server with id: " + id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int countMessagesInLastMinute(String serverId) {
|
||||||
|
var getServerConfig = """
|
||||||
|
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, serverId);
|
||||||
|
var now = Instant.now();
|
||||||
|
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||||
|
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
resultSet.next();
|
||||||
|
return resultSet.getInt(1);
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Getting messages for server with id: " + serverId, e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
logger.error("Iterating over ServerMessages ResultSet for server with id: " + serverId, e);
|
||||||
|
return Integer.MAX_VALUE;
|
||||||
|
}
|
||||||
|
return Integer.MAX_VALUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addMessage(ServerMessage.NewServerMessage serverMessage) {
|
||||||
|
var getServerConfig = """
|
||||||
|
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, serverMessage.serverId());
|
||||||
|
pstmt.setLong(2, serverMessage.userId());
|
||||||
|
pstmt.setInt(3, serverMessage.tokens());
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No message added for server with id: " + serverMessage.serverId());
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Adding message to server with id: " + serverMessage.serverId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public long tokensOfLast30Days(String serverId) {
|
||||||
|
var countTokensOfLast30Days = """
|
||||||
|
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, serverId);
|
||||||
|
var now = Instant.now();
|
||||||
|
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||||
|
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
resultSet.next();
|
||||||
|
return resultSet.getLong(1);
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Counting tokens of the last 30 days from server with id: " + serverId, e);
|
||||||
|
}
|
||||||
|
logger.error("No tokens found for server with id: " + serverId);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,28 +2,219 @@ package de.hhhammer.dchat.db;
|
||||||
|
|
||||||
import de.hhhammer.dchat.db.models.user.UserConfig;
|
import de.hhhammer.dchat.db.models.user.UserConfig;
|
||||||
import de.hhhammer.dchat.db.models.user.UserMessage;
|
import de.hhhammer.dchat.db.models.user.UserMessage;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import javax.sql.DataSource;
|
||||||
|
import java.sql.*;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.time.temporal.ChronoUnit;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import java.util.stream.StreamSupport;
|
||||||
|
|
||||||
public interface UserDBService {
|
public class UserDBService {
|
||||||
Optional<UserConfig> getConfig(String userId);
|
private static final Logger logger = LoggerFactory.getLogger(UserDBService.class);
|
||||||
|
private final DataSource dataSource;
|
||||||
|
|
||||||
Optional<UserConfig> getConfigBy(long id) throws DBException;
|
public UserDBService(DataSource dataSource) {
|
||||||
|
this.dataSource = dataSource;
|
||||||
List<UserConfig> getAllConfigs() throws DBException;
|
}
|
||||||
|
|
||||||
void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException;
|
public Optional<UserConfig> getConfig(String userId) {
|
||||||
|
var getServerConfig = """
|
||||||
void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException;
|
SELECT * FROM user_configs WHERE user_id = ?
|
||||||
|
""";
|
||||||
void deleteConfig(long id) throws DBException;
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
int countMessagesInLastMinute(String userId);
|
) {
|
||||||
|
pstmt.setString(1, userId);
|
||||||
void addMessage(UserMessage.NewUserMessage newUserMessage);
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||||
List<UserMessage> getLastMessages(String userId, int limit);
|
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||||
|
} catch (SQLException e) {
|
||||||
long tokensOfLast30Days(String userId);
|
logger.error("Getting configuration for user with id: " + userId, e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
logger.error("Iterating over ServerConfig ResultSet for user with id: " + userId, e);
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Optional<UserConfig> getConfigBy(long id) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
SELECT * FROM user_configs WHERE id = ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setLong(1, id);
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||||
|
return StreamSupport.stream(iterable.spliterator(), false).findFirst();
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Getting configuration id: " + id, e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
throw new DBException("Iterating over UserConfig ResultSet with id: " + id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<UserConfig> getAllConfigs() throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
SELECT * FROM user_configs
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
Iterable<UserConfig> iterable = () -> new ResultSetIterator<>(resultSet, new UserConfig.UserConfigResultSetTransformer());
|
||||||
|
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Getting all configurations", e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
throw new DBException("Iterating over all UserConfig ResultSet", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addConfig(UserConfig.NewUserConfig newUserConfig) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, newUserConfig.userId());
|
||||||
|
pstmt.setString(2, newUserConfig.systemMessage());
|
||||||
|
pstmt.setInt(3, newUserConfig.contextLength());
|
||||||
|
pstmt.setInt(4, newUserConfig.rateLimit());
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No config added for user with id: " + newUserConfig.userId());
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Adding configuration for user with id: " + newUserConfig.userId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateConfig(long id, UserConfig.NewUserConfig newUserConfig) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, newUserConfig.systemMessage());
|
||||||
|
pstmt.setInt(2, newUserConfig.rateLimit());
|
||||||
|
pstmt.setLong(3, newUserConfig.contextLength());
|
||||||
|
pstmt.setString(4, newUserConfig.userId());
|
||||||
|
pstmt.setLong(5, id);
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No config update with id: " + id);
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Updating configuration with id: " + id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void deleteConfig(long id) throws DBException {
|
||||||
|
var getServerConfig = """
|
||||||
|
DELETE FROM user_configs WHERE id = ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setLong(1, id);
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No config deleted for user with id: " + id);
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
throw new DBException("Deleting configuration with id: " + id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int countMessagesInLastMinute(String userId) {
|
||||||
|
var getServerConfig = """
|
||||||
|
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, userId);
|
||||||
|
var now = Instant.now();
|
||||||
|
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||||
|
pstmt.setTimestamp(3, Timestamp.from(now.minus(1, ChronoUnit.MINUTES)));
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
resultSet.next();
|
||||||
|
return resultSet.getInt(1);
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Getting messages for user with id: " + userId, e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
logger.error("Iterating over ServerMessages ResultSet for user with id: " + userId, e);
|
||||||
|
return Integer.MAX_VALUE;
|
||||||
|
}
|
||||||
|
return Integer.MAX_VALUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addMessage(UserMessage.NewUserMessage newUserMessage) {
|
||||||
|
var getServerConfig = """
|
||||||
|
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, newUserMessage.userId());
|
||||||
|
pstmt.setString(2, newUserMessage.question());
|
||||||
|
pstmt.setString(3, newUserMessage.answer());
|
||||||
|
pstmt.setInt(4, newUserMessage.tokens());
|
||||||
|
int affectedRows = pstmt.executeUpdate();
|
||||||
|
if (affectedRows == 0) {
|
||||||
|
logger.error("No message added for user with id: " + newUserMessage.userId());
|
||||||
|
}
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Adding message to user with id: " + newUserMessage.userId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<UserMessage> getLastMessages(String userId, int limit) {
|
||||||
|
var getLastMessages = """
|
||||||
|
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(getLastMessages)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, userId);
|
||||||
|
pstmt.setInt(2, limit);
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
|
||||||
|
return StreamSupport.stream(iterable.spliterator(), false).toList();
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Fetching last messages for user whit id: " + userId, e);
|
||||||
|
} catch (ResultSetIteratorException e) {
|
||||||
|
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
|
||||||
|
}
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
|
||||||
|
public long tokensOfLast30Days(String userId) {
|
||||||
|
var countTokensOfLast30Days = """
|
||||||
|
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
|
||||||
|
""";
|
||||||
|
try (Connection con = dataSource.getConnection();
|
||||||
|
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
|
||||||
|
) {
|
||||||
|
pstmt.setString(1, userId);
|
||||||
|
var now = Instant.now();
|
||||||
|
pstmt.setTimestamp(2, Timestamp.from(now));
|
||||||
|
pstmt.setTimestamp(3, Timestamp.from(now.minus(30, ChronoUnit.DAYS)));
|
||||||
|
ResultSet resultSet = pstmt.executeQuery();
|
||||||
|
resultSet.next();
|
||||||
|
return resultSet.getLong(1);
|
||||||
|
} catch (SQLException e) {
|
||||||
|
logger.error("Counting tokens of the last 30 days from user with id: " + userId, e);
|
||||||
|
}
|
||||||
|
logger.error("No tokens found for user with id: " + userId);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,20 +6,20 @@ import org.slf4j.LoggerFactory;
|
||||||
/**
|
/**
|
||||||
* Hello world!
|
* Hello world!
|
||||||
*/
|
*/
|
||||||
public final class App {
|
public class App {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
||||||
private static final String DB_MIGRATION_PATH = "db/schema.sql";
|
private static final String DB_MIGRATION_PATH = "db/schema.sql";
|
||||||
|
|
||||||
public static void main(final String[] args) {
|
public static void main(String[] args) {
|
||||||
final String postgresUser = System.getenv("POSTGRES_USER");
|
String postgresUser = System.getenv("POSTGRES_USER");
|
||||||
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||||
final String postgresUrl = System.getenv("POSTGRES_URL");
|
String postgresUrl = System.getenv("POSTGRES_URL");
|
||||||
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
||||||
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
final var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
|
var migrationExecutor = new MigrationExecutor(postgresUrl, postgresUser, postgresPassword);
|
||||||
final var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
|
var dbMigrator = new DBMigrator(migrationExecutor, DB_MIGRATION_PATH);
|
||||||
dbMigrator.run();
|
dbMigrator.run();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package de.hhhammer.dchat.migration;
|
package de.hhhammer.dchat.migration;
|
||||||
|
|
||||||
public final class DBMigrationException extends Exception {
|
public class DBMigrationException extends Exception {
|
||||||
public DBMigrationException(final Throwable cause) {
|
public DBMigrationException(Throwable cause) {
|
||||||
super(cause);
|
super(cause);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,12 +6,12 @@ import org.slf4j.LoggerFactory;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
public final class DBMigrator implements Runnable {
|
public class DBMigrator implements Runnable {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class);
|
private static final Logger logger = LoggerFactory.getLogger(DBMigrator.class);
|
||||||
private final MigrationExecutor migrationExecutor;
|
private final MigrationExecutor migrationExecutor;
|
||||||
private final String resourcePath;
|
private final String resourcePath;
|
||||||
|
|
||||||
public DBMigrator(final MigrationExecutor migrationExecutor, final String resourcePath) {
|
public DBMigrator(MigrationExecutor migrationExecutor, String resourcePath) {
|
||||||
this.migrationExecutor = migrationExecutor;
|
this.migrationExecutor = migrationExecutor;
|
||||||
this.resourcePath = resourcePath;
|
this.resourcePath = resourcePath;
|
||||||
}
|
}
|
||||||
|
@ -19,8 +19,8 @@ public final class DBMigrator implements Runnable {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
logger.info("Starting db migration");
|
logger.info("Starting db migration");
|
||||||
final ClassLoader classLoader = getClass().getClassLoader();
|
ClassLoader classLoader = getClass().getClassLoader();
|
||||||
try (final InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
|
try (InputStream inputStream = classLoader.getResourceAsStream(this.resourcePath)) {
|
||||||
if (inputStream == null) {
|
if (inputStream == null) {
|
||||||
logger.error("Migration file not found: " + resourcePath);
|
logger.error("Migration file not found: " + resourcePath);
|
||||||
throw new RuntimeException("Migration file not found");
|
throw new RuntimeException("Migration file not found");
|
||||||
|
|
|
@ -11,24 +11,24 @@ import java.sql.DriverManager;
|
||||||
import java.sql.SQLException;
|
import java.sql.SQLException;
|
||||||
import java.sql.Statement;
|
import java.sql.Statement;
|
||||||
|
|
||||||
public final class MigrationExecutor {
|
public class MigrationExecutor {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class);
|
private static final Logger logger = LoggerFactory.getLogger(MigrationExecutor.class);
|
||||||
private final String jdbcConnectionString;
|
private final String jdbcConnectionString;
|
||||||
private final String username;
|
private final String username;
|
||||||
private final String password;
|
private final String password;
|
||||||
|
|
||||||
public MigrationExecutor(final String jdbcConnectionString, final String username, final String password) {
|
public MigrationExecutor(String jdbcConnectionString, String username, String password) {
|
||||||
this.jdbcConnectionString = jdbcConnectionString;
|
this.jdbcConnectionString = jdbcConnectionString;
|
||||||
this.username = username;
|
this.username = username;
|
||||||
this.password = password;
|
this.password = password;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void migrate(final InputStream input) throws DBMigrationException {
|
public void migrate(InputStream input) throws DBMigrationException {
|
||||||
try (final Connection con = DriverManager
|
try (Connection con = DriverManager
|
||||||
.getConnection(this.jdbcConnectionString, this.username, this.password);
|
.getConnection(this.jdbcConnectionString, this.username, this.password);
|
||||||
final Statement stmp = con.createStatement();
|
Statement stmp = con.createStatement();
|
||||||
) {
|
) {
|
||||||
final String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
|
String content = new String(input.readAllBytes(), StandardCharsets.UTF_8);
|
||||||
stmp.execute(content);
|
stmp.execute(content);
|
||||||
} catch (SQLException | IOException e) {
|
} catch (SQLException | IOException e) {
|
||||||
throw new DBMigrationException(e);
|
throw new DBMigrationException(e);
|
||||||
|
|
|
@ -12,6 +12,10 @@
|
||||||
<version>1.0-SNAPSHOT</version>
|
<version>1.0-SNAPSHOT</version>
|
||||||
<name>web</name>
|
<name>web</name>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<node.version>v18.16.0</node.version>
|
||||||
|
</properties>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>de.hhhammer.dchat</groupId>
|
<groupId>de.hhhammer.dchat</groupId>
|
||||||
|
|
|
@ -2,41 +2,41 @@ package de.hhhammer.dchat.web;
|
||||||
|
|
||||||
import com.zaxxer.hikari.HikariConfig;
|
import com.zaxxer.hikari.HikariConfig;
|
||||||
import com.zaxxer.hikari.HikariDataSource;
|
import com.zaxxer.hikari.HikariDataSource;
|
||||||
import de.hhhammer.dchat.db.PostgresServerDBService;
|
import de.hhhammer.dchat.db.ServerDBService;
|
||||||
import de.hhhammer.dchat.db.PostgresUserDBService;
|
import de.hhhammer.dchat.db.UserDBService;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hello world!
|
* Hello world!
|
||||||
*/
|
*/
|
||||||
public final class App {
|
public class App {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
private static final Logger logger = LoggerFactory.getLogger(App.class);
|
||||||
|
|
||||||
public static void main(final String[] args) {
|
public static void main(String[] args) {
|
||||||
final String postgresUser = System.getenv("POSTGRES_USER");
|
String postgresUser = System.getenv("POSTGRES_USER");
|
||||||
final String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
String postgresPassword = System.getenv("POSTGRES_PASSWORD");
|
||||||
final String postgresUrl = System.getenv("POSTGRES_URL");
|
String postgresUrl = System.getenv("POSTGRES_URL");
|
||||||
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
if (postgresUser == null || postgresPassword == null || postgresUrl == null) {
|
||||||
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
logger.error("Missing environment variables: POSTGRES_USER and/or POSTGRES_PASSWORD and/or POSTGRES_URL");
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
final String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
|
String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
|
||||||
final int apiPort = Integer.parseInt(apiPortStr);
|
int apiPort = Integer.parseInt(apiPortStr);
|
||||||
final boolean debug = "true".equals(System.getenv("API_DEBUG"));
|
boolean debug = "true".equals(System.getenv("API_DEBUG"));
|
||||||
|
|
||||||
final var config = new HikariConfig();
|
var config = new HikariConfig();
|
||||||
config.setJdbcUrl(postgresUrl);
|
config.setJdbcUrl(postgresUrl);
|
||||||
config.setUsername(postgresUser);
|
config.setUsername(postgresUser);
|
||||||
config.setPassword(postgresPassword);
|
config.setPassword(postgresPassword);
|
||||||
|
|
||||||
try (final var ds = new HikariDataSource(config)) {
|
try (var ds = new HikariDataSource(config)) {
|
||||||
final var serverDBService = new PostgresServerDBService(ds);
|
var serverDBService = new ServerDBService(ds);
|
||||||
final var userDBService = new PostgresUserDBService(ds);
|
var userDBService = new UserDBService(ds);
|
||||||
final var appConfig = new AppConfig(apiPort, debug);
|
var appConfig = new AppConfig(apiPort, debug);
|
||||||
|
|
||||||
final var webApi = new WebAPI(serverDBService, userDBService, appConfig);
|
var webApi = new WebAPI(serverDBService, userDBService, appConfig);
|
||||||
webApi.run();
|
webApi.run();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import de.hhhammer.dchat.web.server.ConfigCrudHandler;
|
||||||
import de.hhhammer.dchat.web.user.ConfigUserCrudHandler;
|
import de.hhhammer.dchat.web.user.ConfigUserCrudHandler;
|
||||||
import io.javalin.Javalin;
|
import io.javalin.Javalin;
|
||||||
import io.javalin.http.HttpStatus;
|
import io.javalin.http.HttpStatus;
|
||||||
|
import io.javalin.http.staticfiles.Location;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
@ -15,13 +16,13 @@ import java.util.concurrent.ExecutionException;
|
||||||
import static io.javalin.apibuilder.ApiBuilder.crud;
|
import static io.javalin.apibuilder.ApiBuilder.crud;
|
||||||
import static io.javalin.apibuilder.ApiBuilder.path;
|
import static io.javalin.apibuilder.ApiBuilder.path;
|
||||||
|
|
||||||
public final class WebAPI implements Runnable {
|
public class WebAPI implements Runnable {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(WebAPI.class);
|
private static final Logger logger = LoggerFactory.getLogger(WebAPI.class);
|
||||||
private final ServerDBService serverDBService;
|
private final ServerDBService serverDBService;
|
||||||
private final UserDBService userDBService;
|
private final UserDBService userDBService;
|
||||||
private final AppConfig appConfig;
|
private final AppConfig appConfig;
|
||||||
|
|
||||||
public WebAPI(final ServerDBService serverDBService, final UserDBService userDBService, final AppConfig appConfig) {
|
public WebAPI(ServerDBService serverDBService, UserDBService userDBService, AppConfig appConfig) {
|
||||||
this.serverDBService = serverDBService;
|
this.serverDBService = serverDBService;
|
||||||
this.userDBService = userDBService;
|
this.userDBService = userDBService;
|
||||||
this.appConfig = appConfig;
|
this.appConfig = appConfig;
|
||||||
|
@ -30,12 +31,12 @@ public final class WebAPI implements Runnable {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
logger.info("Starting web application");
|
logger.info("Starting web application");
|
||||||
final Javalin app = Javalin.create(config -> {
|
var app = Javalin.create(config -> {
|
||||||
if (appConfig.debug()) config.plugins.enableDevLogging();
|
if (appConfig.debug()) config.plugins.enableDevLogging();
|
||||||
config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method
|
config.http.prefer405over404 = true; // return 405 instead of 404 if path is mapped to different HTTP method
|
||||||
config.http.defaultContentType = "application/json";
|
config.http.defaultContentType = "application/json";
|
||||||
});
|
});
|
||||||
final var waitForShutdown = new CompletableFuture<Void>();
|
var waitForShutdown = new CompletableFuture<Void>();
|
||||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||||
logger.info("Shutting down web application");
|
logger.info("Shutting down web application");
|
||||||
app.stop();
|
app.stop();
|
||||||
|
|
|
@ -10,21 +10,18 @@ import org.jetbrains.annotations.NotNull;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.util.List;
|
public class ConfigCrudHandler implements CrudHandler {
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
public final class ConfigCrudHandler implements CrudHandler {
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(ConfigCrudHandler.class);
|
||||||
|
|
||||||
private final ServerDBService serverDBService;
|
private final ServerDBService serverDBService;
|
||||||
|
|
||||||
public ConfigCrudHandler(final ServerDBService serverDBService) {
|
public ConfigCrudHandler(ServerDBService serverDBService) {
|
||||||
this.serverDBService = serverDBService;
|
this.serverDBService = serverDBService;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void create(@NotNull final Context context) {
|
public void create(@NotNull Context context) {
|
||||||
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
||||||
try {
|
try {
|
||||||
this.serverDBService.addConfig(body);
|
this.serverDBService.addConfig(body);
|
||||||
} catch (DBException e) {
|
} catch (DBException e) {
|
||||||
|
@ -35,10 +32,9 @@ public final class ConfigCrudHandler implements CrudHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void delete(@NotNull final Context context, @NotNull final String s) {
|
public void delete(@NotNull Context context, @NotNull String s) {
|
||||||
final var id = Long.parseLong(s);
|
|
||||||
try {
|
try {
|
||||||
this.serverDBService.deleteConfig(id);
|
this.serverDBService.deleteConfig(Long.parseLong(s));
|
||||||
context.status(HttpStatus.NO_CONTENT);
|
context.status(HttpStatus.NO_CONTENT);
|
||||||
} catch (DBException e) {
|
} catch (DBException e) {
|
||||||
logger.error("Deleting configuration with id: " + s, e);
|
logger.error("Deleting configuration with id: " + s, e);
|
||||||
|
@ -47,9 +43,9 @@ public final class ConfigCrudHandler implements CrudHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void getAll(@NotNull final Context context) {
|
public void getAll(@NotNull Context context) {
|
||||||
try {
|
try {
|
||||||
final List<ServerConfig> allowedServers = this.serverDBService.getAllConfigs();
|
var allowedServers = this.serverDBService.getAllConfigs();
|
||||||
context.json(allowedServers);
|
context.json(allowedServers);
|
||||||
} catch (DBException e) {
|
} catch (DBException e) {
|
||||||
logger.error("Getting all server configs", e);
|
logger.error("Getting all server configs", e);
|
||||||
|
@ -58,10 +54,10 @@ public final class ConfigCrudHandler implements CrudHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void getOne(@NotNull final Context context, @NotNull final String s) {
|
public void getOne(@NotNull Context context, @NotNull String s) {
|
||||||
final var id = Long.parseLong(s);
|
var id = Long.parseLong(s);
|
||||||
try {
|
try {
|
||||||
final Optional<ServerConfig> server = this.serverDBService.getConfigBy(id);
|
var server = this.serverDBService.getConfigBy(id);
|
||||||
if (server.isEmpty()) {
|
if (server.isEmpty()) {
|
||||||
context.status(HttpStatus.NOT_FOUND);
|
context.status(HttpStatus.NOT_FOUND);
|
||||||
return;
|
return;
|
||||||
|
@ -74,9 +70,9 @@ public final class ConfigCrudHandler implements CrudHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void update(@NotNull final Context context, @NotNull final String idString) {
|
public void update(@NotNull Context context, @NotNull String idString) {
|
||||||
final ServerConfig.NewServerConfig body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
var body = context.bodyAsClass(ServerConfig.NewServerConfig.class);
|
||||||
final var id = Long.parseLong(idString);
|
var id = Long.parseLong(idString);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
this.serverDBService.updateConfig(id, body);
|
this.serverDBService.updateConfig(id, body);
|
||||||
|
|
|
@ -10,21 +10,18 @@ import org.jetbrains.annotations.NotNull;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.util.List;
|
public class ConfigUserCrudHandler implements CrudHandler {
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
public final class ConfigUserCrudHandler implements CrudHandler {
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(ConfigUserCrudHandler.class);
|
||||||
|
|
||||||
private final UserDBService userDBService;
|
private final UserDBService userDBService;
|
||||||
|
|
||||||
public ConfigUserCrudHandler(final UserDBService userDBService) {
|
public ConfigUserCrudHandler(UserDBService userDBService) {
|
||||||
this.userDBService = userDBService;
|
this.userDBService = userDBService;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void create(@NotNull final Context context) {
|
public void create(@NotNull Context context) {
|
||||||
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
||||||
try {
|
try {
|
||||||
this.userDBService.addConfig(body);
|
this.userDBService.addConfig(body);
|
||||||
} catch (DBException e) {
|
} catch (DBException e) {
|
||||||
|
@ -36,10 +33,9 @@ public final class ConfigUserCrudHandler implements CrudHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void delete(@NotNull final Context context, @NotNull final String s) {
|
public void delete(@NotNull Context context, @NotNull String s) {
|
||||||
final var id = Long.parseLong(s);
|
|
||||||
try {
|
try {
|
||||||
this.userDBService.deleteConfig(id);
|
this.userDBService.deleteConfig(Long.parseLong(s));
|
||||||
context.status(HttpStatus.NO_CONTENT);
|
context.status(HttpStatus.NO_CONTENT);
|
||||||
} catch (DBException e) {
|
} catch (DBException e) {
|
||||||
logger.error("Deleting configuration with id: " + s, e);
|
logger.error("Deleting configuration with id: " + s, e);
|
||||||
|
@ -48,9 +44,9 @@ public final class ConfigUserCrudHandler implements CrudHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void getAll(@NotNull final Context context) {
|
public void getAll(@NotNull Context context) {
|
||||||
try {
|
try {
|
||||||
final List<UserConfig> allowedServers = this.userDBService.getAllConfigs();
|
var allowedServers = this.userDBService.getAllConfigs();
|
||||||
context.json(allowedServers);
|
context.json(allowedServers);
|
||||||
} catch (DBException e) {
|
} catch (DBException e) {
|
||||||
logger.error("Getting all user configs", e);
|
logger.error("Getting all user configs", e);
|
||||||
|
@ -59,10 +55,10 @@ public final class ConfigUserCrudHandler implements CrudHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void getOne(@NotNull final Context context, @NotNull final String s) {
|
public void getOne(@NotNull Context context, @NotNull String s) {
|
||||||
final var id = Long.parseLong(s);
|
var id = Long.parseLong(s);
|
||||||
try {
|
try {
|
||||||
final Optional<UserConfig> server = this.userDBService.getConfigBy(id);
|
var server = this.userDBService.getConfigBy(id);
|
||||||
if (server.isEmpty()) {
|
if (server.isEmpty()) {
|
||||||
context.status(HttpStatus.NOT_FOUND);
|
context.status(HttpStatus.NOT_FOUND);
|
||||||
return;
|
return;
|
||||||
|
@ -76,8 +72,8 @@ public final class ConfigUserCrudHandler implements CrudHandler {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void update(@NotNull Context context, @NotNull String idString) {
|
public void update(@NotNull Context context, @NotNull String idString) {
|
||||||
final UserConfig.NewUserConfig body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
var body = context.bodyAsClass(UserConfig.NewUserConfig.class);
|
||||||
final var id = Long.parseLong(idString);
|
var id = Long.parseLong(idString);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
this.userDBService.updateConfig(id, body);
|
this.userDBService.updateConfig(id, body);
|
||||||
|
|
Loading…
Reference in a new issue