Compare commits

..

No commits in common. "4b60f599566647ccb64f2c853130d559c586b48a" and "f5e40e324d0ed6dab8bb9a45bf79fb3165504976" have entirely different histories.

14 changed files with 72 additions and 176 deletions

View file

@ -18,23 +18,16 @@
<artifactId>db</artifactId>
<version>1.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
</dependency>
<dependency>
<groupId>org.javacord</groupId>
<artifactId>javacord</artifactId>
<version>3.8.0</version>
<type>pom</type>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-to-slf4j</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.20.0</version>
</dependency>
</dependencies>

View file

@ -1,7 +1,5 @@
package de.hhhammer.dchat.bot;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.bot.openai.ChatGPTService;
import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.UserDBService;
@ -33,18 +31,10 @@ public class App {
}
var chatGPTService = new ChatGPTService(openaiApiKey, HttpClient.newHttpClient());
var serverDBService = new ServerDBService(postgresUrl, postgresUser, postgresPassword);
var userDBService = new UserDBService(postgresUrl, postgresUser, postgresPassword);
var config = new HikariConfig();
config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser);
config.setPassword(postgresPassword);
try (var ds = new HikariDataSource(config)) {
var serverDBService = new ServerDBService(ds);
var userDBService = new UserDBService(ds);
var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
discordBot.run();
}
var discordBot = new DiscordBot(serverDBService, userDBService, chatGPTService, discordApiKey);
discordBot.run();
}
}

View file

@ -12,8 +12,6 @@ import org.javacord.api.interaction.SlashCommand;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
public class DiscordBot implements Runnable {
@ -38,10 +36,9 @@ public class DiscordBot implements Runnable {
.setToken(discordApiKey)
.login()
.join();
var future = new CompletableFuture<Void>();
Runtime.getRuntime().addShutdownHook(Thread.ofVirtual().unstarted(() -> {
logger.info("Shutting down Discord application");
discordApi.disconnect().thenAccept(future::complete);
discordApi.disconnect();
}));
var token = SlashCommand.with("tokens", "Check how many tokens where spend on this server")
.createGlobal(discordApi)
@ -72,11 +69,5 @@ public class DiscordBot implements Runnable {
// Print the invite url of your bot
logger.info("You can invite the bot by using the following url: " + discordApi.createBotInvite());
try {
future.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
}

View file

@ -33,7 +33,7 @@ public class ServerMessageHandler implements MessageHandler {
var systemMessage = this.serverDBService.getConfig(String.valueOf(serverId)).get().systemMessage();
var request = event.getMessage().getType() == MessageType.REPLY ?
new ChatGPTRequestBuilder()
.replyRequest(event.getMessage()
.contextRequest(event.getMessage()
.getMessageReference()
.map(MessageReference::getMessage)
.flatMap(m -> m)

View file

@ -24,14 +24,9 @@ public class UserMessageHandler implements MessageHandler {
@Override
public void handle(MessageCreateEvent event) throws ResponseException, IOException, InterruptedException {
String content = event.getReadableMessageContent();
var userId = String.valueOf(event.getMessageAuthor().getId());
var config = this.userDBService.getConfig(userId).get();
var systemMessage = config.systemMessage();
var context = this.userDBService.getLastMessages(userId, config.contextLength())
.stream()
.map(userMessage -> new ChatGPTRequestBuilder.PreviousInteraction(userMessage.question(), userMessage.answer()))
.toList();
var request = new ChatGPTRequestBuilder().contextRequest(context, content, systemMessage);
var userId = event.getMessageAuthor().getId();
var systemMessage = this.userDBService.getConfig(String.valueOf(userId)).get().systemMessage();
var request = new ChatGPTRequestBuilder().simpleRequest(content, systemMessage);
var response = this.chatGPTService.submit(request);
if (response.choices().size() < 1) {
event.getMessage().reply("No response available");

View file

@ -1,7 +1,6 @@
package de.hhhammer.dchat.bot.openai;
import de.hhhammer.dchat.bot.openai.models.ChatGPTRequest;
import java.util.ArrayList;
import java.util.List;
@ -19,7 +18,7 @@ public class ChatGPTRequestBuilder {
);
}
public ChatGPTRequest replyRequest(List<String> contextMessages, String message, String systemMessage) {
public ChatGPTRequest contextRequest(List<String> contextMessages, String message, String systemMessage) {
List<ChatGPTRequest.Message> messages = new ArrayList<>();
messages.add(new ChatGPTRequest.Message("system", systemMessage));
var context = contextMessages.stream()
@ -33,26 +32,4 @@ public class ChatGPTRequestBuilder {
0.7f
);
}
public ChatGPTRequest contextRequest(List<PreviousInteraction> contextMessages, String message, String systemMessage) {
List<ChatGPTRequest.Message> messages = new ArrayList<>();
messages.add(new ChatGPTRequest.Message("system", systemMessage));
var context = contextMessages.stream()
.map(m -> List.of(
new ChatGPTRequest.Message("user", m.question),
new ChatGPTRequest.Message("assistant", m.answer)
))
.flatMap(List::stream)
.toList();
messages.addAll(context);
messages.add(new ChatGPTRequest.Message("user", message));
return new ChatGPTRequest(
model,
messages,
0.7f
);
}
public record PreviousInteraction(String question, String answer) {
}
}

View file

@ -29,7 +29,7 @@ public class ChatGPTService {
.POST(HttpRequest.BodyPublishers.ofByteArray(data))
.setHeader("Content-Type", "application/json")
.setHeader("Authorization", "Bearer " + this.apiKey)
.timeout(Duration.ofMinutes(5))
.timeout(Duration.ofSeconds(90))
.build();
var responseStream = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());

View file

@ -17,11 +17,11 @@
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<artifactId>jackson-databind</artifactId>
<version>2.15.0-rc2</version>
</dependency>
</dependencies>

View file

@ -5,7 +5,6 @@ import de.hhhammer.dchat.db.models.server.ServerMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
@ -15,17 +14,22 @@ import java.util.stream.StreamSupport;
public class ServerDBService {
private static final Logger logger = LoggerFactory.getLogger(ServerDBService.class);
private final DataSource dataSource;
private final String jdbcConnectionString;
private final String username;
private final String password;
public ServerDBService(DataSource dataSource) {
this.dataSource = dataSource;
public ServerDBService(String jdbcConnectionString, String username, String password) {
this.jdbcConnectionString = jdbcConnectionString;
this.username = username;
this.password = password;
}
public Optional<ServerConfig> getConfig(String serverId) {
var getServerConfig = """
SELECT * FROM server_configs WHERE server_id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
@ -45,7 +49,7 @@ public class ServerDBService {
var getAllowedServerSql = """
SELECT * FROM server_configs
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getAllowedServerSql)
) {
ResultSet resultSet = pstmt.executeQuery();
@ -62,7 +66,8 @@ public class ServerDBService {
var getServerConfig = """
SELECT * FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
@ -80,7 +85,8 @@ public class ServerDBService {
var getServerConfig = """
INSERT INTO server_configs (server_id, system_message, rate_limit) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.serverId());
@ -99,7 +105,8 @@ public class ServerDBService {
var getServerConfig = """
UPDATE server_configs SET system_message = ?, rate_limit = ?, server_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newServerConfig.systemMessage());
@ -119,7 +126,8 @@ public class ServerDBService {
var getServerConfig = """
DELETE FROM server_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
@ -136,7 +144,8 @@ public class ServerDBService {
var getServerConfig = """
SELECT count(*) FROM server_messages WHERE server_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverId);
@ -159,7 +168,8 @@ public class ServerDBService {
var getServerConfig = """
INSERT INTO server_messages (server_id, user_id, tokens) VALUES (?,?,?)
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, serverMessage.serverId());
@ -178,7 +188,8 @@ public class ServerDBService {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM server_messages WHERE server_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, serverId);

View file

@ -5,7 +5,6 @@ import de.hhhammer.dchat.db.models.user.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
@ -15,17 +14,22 @@ import java.util.stream.StreamSupport;
public class UserDBService {
private static final Logger logger = LoggerFactory.getLogger(UserDBService.class);
private final DataSource dataSource;
private final String jdbcConnectionString;
private final String username;
private final String password;
public UserDBService(DataSource dataSource) {
this.dataSource = dataSource;
public UserDBService(String jdbcConnectionString, String username, String password) {
this.jdbcConnectionString = jdbcConnectionString;
this.username = username;
this.password = password;
}
public Optional<UserConfig> getConfig(String userId) {
var getServerConfig = """
SELECT * FROM user_configs WHERE user_id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
@ -45,7 +49,8 @@ public class UserDBService {
var getServerConfig = """
SELECT * FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
@ -63,7 +68,8 @@ public class UserDBService {
var getServerConfig = """
SELECT * FROM user_configs
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
ResultSet resultSet = pstmt.executeQuery();
@ -80,7 +86,8 @@ public class UserDBService {
var getServerConfig = """
INSERT INTO user_configs (user_id, system_message, context_length, rate_limit) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.userId());
@ -100,7 +107,8 @@ public class UserDBService {
var getServerConfig = """
UPDATE user_configs SET system_message = ?, context_length = ?, rate_limit = ?, user_id = ? WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserConfig.systemMessage());
@ -121,7 +129,8 @@ public class UserDBService {
var getServerConfig = """
DELETE FROM user_configs WHERE id = ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setLong(1, id);
@ -138,7 +147,8 @@ public class UserDBService {
var getServerConfig = """
SELECT count(*) FROM user_messages WHERE user_id = ? AND time <= ? and time >= ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, userId);
@ -161,7 +171,8 @@ public class UserDBService {
var getServerConfig = """
INSERT INTO user_messages (user_id, question, answer, tokens) VALUES (?,?,?,?)
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(getServerConfig)
) {
pstmt.setString(1, newUserMessage.userId());
@ -177,31 +188,12 @@ public class UserDBService {
}
}
public List<UserMessage> getLastMessages(String userId, int limit) {
var getLastMessages = """
SELECT * FROM user_messages WHERE user_id = ? ORDER BY time DESC LIMIT ?
""";
try (Connection con = dataSource.getConnection();
PreparedStatement pstmt = con.prepareStatement(getLastMessages)
) {
pstmt.setString(1, userId);
pstmt.setInt(2, limit);
ResultSet resultSet = pstmt.executeQuery();
Iterable<UserMessage> iterable = () -> new ResultSetIterator<>(resultSet, new UserMessage.UserMessageResultSetTransformer());
return StreamSupport.stream(iterable.spliterator(), false).toList();
} catch (SQLException e) {
logger.error("Fetching last messages for user whit id: " + userId, e);
} catch (ResultSetIteratorException e) {
logger.error("Iterating over messages ResultSet from user with id: " + userId, e);
}
return List.of();
}
public long tokensOfLast30Days(String userId) {
var countTokensOfLast30Days = """
SELECT sum(tokens) FROM user_messages WHERE user_id = ? AND time < ? AND time >= ?
""";
try (Connection con = dataSource.getConnection();
try (Connection con = DriverManager
.getConnection(this.jdbcConnectionString, this.username, this.password);
PreparedStatement pstmt = con.prepareStatement(countTokensOfLast30Days)
) {
pstmt.setString(1, userId);

View file

@ -16,7 +16,6 @@
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<scope>runtime</scope>
</dependency>
</dependencies>

38
pom.xml
View file

@ -30,7 +30,6 @@
<maven.compiler.source>19</maven.compiler.source>
<maven.compiler.target>19</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<jackson.version>2.15.1</jackson.version>
</properties>
@ -76,43 +75,6 @@
<version>42.6.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
<version>5.0.1</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>io.javalin</groupId>
<artifactId>javalin</artifactId>
<version>5.5.0</version>
</dependency>
<dependency>
<groupId>org.javacord</groupId>
<artifactId>javacord</artifactId>
<version>3.8.0</version>
<type>pom</type>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-to-slf4j</artifactId>
<version>2.20.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</dependencyManagement>
<build>

View file

@ -22,21 +22,20 @@
<artifactId>db</artifactId>
<version>1.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
</dependency>
<dependency>
<groupId>io.javalin</groupId>
<artifactId>javalin</artifactId>
<version>5.5.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.15.0-rc2</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>2.14.2</version>
</dependency>
</dependencies>

View file

@ -1,7 +1,5 @@
package de.hhhammer.dchat.web;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import de.hhhammer.dchat.db.ServerDBService;
import de.hhhammer.dchat.db.UserDBService;
import io.javalin.Javalin;
@ -23,23 +21,12 @@ public class App {
System.exit(1);
}
var serverDBService = new ServerDBService(postgresUrl, postgresUser, postgresPassword);
var userDBService = new UserDBService(postgresUrl, postgresUser, postgresPassword);
String apiPortStr = System.getenv("API_PORT") != null ? System.getenv("API_PORT") : "8080";
int apiPort = Integer.parseInt(apiPortStr);
var config = new HikariConfig();
config.setJdbcUrl(postgresUrl);
config.setUsername(postgresUser);
config.setPassword(postgresPassword);
config.addDataSourceProperty("cachePrepStmts", "true");
config.addDataSourceProperty("prepStmtCacheSize", "250");
config.addDataSourceProperty("prepStmtCacheSqlLimit", "2048");
try (var ds = new HikariDataSource(config)) {
var serverDBService = new ServerDBService(ds);
var userDBService = new UserDBService(ds);
var webApi = new WebAPI(serverDBService, userDBService, apiPort);
webApi.run();
}
var webApi = new WebAPI(serverDBService, userDBService, apiPort);
webApi.run();
}
}