본문 바로가기
AI

DB 데이터를 직접 머신러닝 모델 학습에 활용하는 방법

by david100gom 2025. 3. 3.

📌 DB 데이터를 직접 머신러닝 모델 학습에 활용하는 방법

  1. JDBC를 사용하여 MySQL (또는 다른 DB)에서 데이터 가져오기
  2. 데이터를 Weka, Smile, DL4J 등의 라이브러리를 활용하여 학습
  3. 학습된 모델을 저장 및 예측 API로 제공

🔹 1. Weka를 활용한 Random Forest 학습

📌 1️⃣ DB에서 데이터 가져오기 (MySQL 예제)

먼저, 데이터베이스에서 냉장고 온도 데이터를 가져와서 Weka Instances 형식으로 변환합니다.

import weka.classifiers.trees.RandomForest;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.experiment.InstanceQuery;

public class WekaDBTraining {
    public static void main(String[] args) throws Exception {
        // MySQL 연결 및 데이터 가져오기
        InstanceQuery query = new InstanceQuery();
        query.setDatabaseURL("jdbc:mysql://localhost:3306/sensordb");
        query.setUsername("root");
        query.setPassword("password");
        query.setQuery("SELECT temperature, is_abnormal FROM fridge_data");

        // 데이터를 Weka Instances 형식으로 변환
        Instances data = query.retrieveInstances();
        data.setClassIndex(data.numAttributes() - 1); // 마지막 열을 클래스 설정

        // 랜덤 포레스트 모델 훈련
        RandomForest model = new RandomForest();
        model.buildClassifier(data);

        // 모델 저장
        weka.core.SerializationHelper.write("fridge_model.model", model);
        System.out.println("Model trained and saved successfully!");
    }
}

➡️ DB에서 데이터를 직접 가져와 Weka Random Forest 모델을 학습합니다.


📌 2️⃣ 저장된 모델을 활용한 예측

DB에서 실시간 데이터를 가져와 예측을 수행하는 코드입니다.

import weka.classifiers.trees.RandomForest;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.experiment.InstanceQuery;
import java.io.FileInputStream;
import java.io.ObjectInputStream;

public class WekaDBPredictor {
    public static void main(String[] args) throws Exception {
        // 저장된 모델 로드
        ObjectInputStream in = new ObjectInputStream(new FileInputStream("fridge_model.model"));
        RandomForest model = (RandomForest) in.readObject();
        in.close();

        // 실시간 데이터 가져오기
        InstanceQuery query = new InstanceQuery();
        query.setDatabaseURL("jdbc:mysql://localhost:3306/sensordb");
        query.setUsername("root");
        query.setPassword("password");
        query.setQuery("SELECT temperature FROM fridge_data ORDER BY timestamp DESC LIMIT 1");

        Instances dataset = query.retrieveInstances();
        dataset.setClassIndex(dataset.numAttributes() - 1);

        // 예측 수행
        for (Instance instance : dataset) {
            double prediction = model.classifyInstance(instance);
            System.out.println("Predicted Anomaly: " + (prediction == 1 ? "Yes (Anomaly Detected)" : "No (Normal)"));
        }
    }
}

➡️ DB에서 최근 데이터를 가져와 실시간으로 예측을 수행합니다.


🔹 2. Smile을 활용한 Isolation Forest 학습

Smile은 Java 기반의 머신러닝 라이브러리로, Isolation Forest를 사용하여 이상치를 감지할 수 있습니다.

📌 1️⃣ DB에서 데이터를 불러와 Isolation Forest 학습

import smile.data.DataFrame;
import smile.data.vector.DoubleVector;
import smile.io.JDBC;
import smile.anomaly.IsolationForest;
import java.sql.Connection;
import java.sql.DriverManager;

public class SmileDBAnomaly {
    public static void main(String[] args) throws Exception {
        // MySQL 데이터베이스 연결
        String url = "jdbc:mysql://localhost:3306/sensordb";
        Connection conn = DriverManager.getConnection(url, "root", "password");

        // DB에서 데이터 가져오기
        DataFrame df = JDBC.read(conn, "SELECT temperature FROM fridge_data");

        // 온도 데이터 추출
        double[][] tempData = df.column("temperature").toDoubleArray();

        // Isolation Forest 모델 학습
        IsolationForest iforest = IsolationForest.fit(tempData);

        // 모델 저장
        iforest.save("isolation_forest_model.bin");
        System.out.println("Model trained and saved successfully!");
    }
}

➡️ DB에서 온도 데이터를 가져와 Isolation Forest 모델을 학습합니다.


📌 2️⃣ 실시간 데이터 가져와 이상 감지

import smile.anomaly.IsolationForest;
import smile.io.JDBC;
import smile.data.DataFrame;
import java.sql.Connection;
import java.sql.DriverManager;

public class SmileDBAnomalyPredict {
    public static void main(String[] args) throws Exception {
        // 모델 로드
        IsolationForest model = IsolationForest.load("isolation_forest_model.bin");

        // MySQL 데이터베이스 연결
        String url = "jdbc:mysql://localhost:3306/sensordb";
        Connection conn = DriverManager.getConnection(url, "root", "password");

        // 실시간 데이터 가져오기
        DataFrame df = JDBC.read(conn, "SELECT temperature FROM fridge_data ORDER BY timestamp DESC LIMIT 1");

        // 예측 수행
        double[] score = model.score(df.toArray());
        System.out.println("Anomaly Score: " + score[0]);

        if (score[0] < -0.2) { // 특정 임계값 이하일 경우 이상 감지
            System.out.println("⚠ Anomaly detected!");
        } else {
            System.out.println("✅ Normal data.");
        }
    }
}

➡️ DB에서 최근 온도를 가져와 이상징후를 판별합니다.


🔹 3. DL4J (딥러닝) 기반 LSTM 모델

LSTM 모델을 활용하면 시계열 데이터를 기반으로 이상 감지를 수행할 수 있습니다.

📌 1️⃣ LSTM 모델 학습

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;

public class LSTMDBTrainer {
    private static final String DB_URL = "jdbc:mysql://localhost:3306/your_database";
    private static final String DB_USER = "your_username";
    private static final String DB_PASSWORD = "your_password";

    public static void main(String[] args) throws Exception {
        // 데이터 로드
        double[][] data = loadDataFromDB();

        if (data == null || data.length == 0) {
            System.err.println("❌ 데이터베이스에서 데이터를 불러오지 못했습니다.");
            return;
        }

        // 입력(x)과 출력(y) 데이터 생성
        double[] x = new double[data.length];
        double[] y = new double[data.length];

        for (int i = 0; i < data.length; i++) {
            x[i] = data[i][0]; // timestamp 값 (필요시 정규화)
            y[i] = data[i][1]; // 측정된 value
        }

        // LSTM 모델 구성
        MultiLayerNetwork model = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new LSTM.Builder().nIn(1).nOut(50)
                        .activation(Activation.TANH).build())
                .layer(1, new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).nIn(50).nOut(1)
                        .build())
                .build());
        model.init();

        // 데이터 ND4J 변환
        DataSet dataset = new DataSet(
                Nd4j.create(x, new int[]{x.length, 1}),
                Nd4j.create(y, new int[]{y.length, 1})
        );

        // 학습
        model.setListeners(new ScoreIterationListener(10));
        model.fit(dataset);

        // 모델 저장
        File modelFile = new File(System.getProperty("user.dir"), "lstm_model.zip");
        model.save(modelFile);
        System.out.println("✅ LSTM Model saved at: " + modelFile.getAbsolutePath());
    }

    /**
     * 데이터베이스에서 시계열 데이터를 불러오는 함수
     */
    private static double[][] loadDataFromDB() {
        List<double[]> dataList = new ArrayList<>();

        try (Connection conn = DriverManager.getConnection(DB_URL, DB_USER, DB_PASSWORD);
             Statement stmt = conn.createStatement();
             ResultSet rs = stmt.executeQuery("SELECT UNIX_TIMESTAMP(timestamp), value FROM time_series_data ORDER BY timestamp ASC")) {

            while (rs.next()) {
                double timestamp = rs.getDouble(1);  // UNIX Timestamp 변환
                double value = rs.getDouble(2);
                dataList.add(new double[]{timestamp, value});
            }

        } catch (SQLException e) {
            e.printStackTrace();
        }

        return dataList.toArray(new double[0][]);
    }
}

➡️ LSTM 모델을 학습 후 저장하여 시계열 이상 탐지에 활용할 수 있습니다.


📌 결론

방법 설명 추천

Random Forest (Weka) 지도 학습 (이상/정상 레이블 필요) ✅ 레이블이 있는 경우
Isolation Forest (Smile) 비지도 학습 (정상 데이터만 필요) ✅ 레이블 없는 경우
LSTM (DL4J) 시계열 이상 탐지 (온도 변화를 학습) ✅ 시계열 분석 필요

 

728x90

댓글