언어/Java

Java(자바) 알고리즘 - 로지스틱회귀(Logistic Regression) 오픈소스 및 예제

[좋은사람] 2018. 10. 18. 10:27



Java - 로지스틱 회귀(Logistic Regression)



프로젝트 진행 중에 Java를 활용한 로지스틱 회귀 관련 알고리즘을 발견해서 사용해 본 후 해당 내용을
공유해봅니다.

로지스틱 회귀는 주로 반응변수가 주로 이진형 값에서 주로 사용되는 회귀 분석 방법이라고 정의할 수
있습니다.

해석의 편의성과 샘플링 데이터의 계수 추정치를 편하게 계산할 수 있는 관계로 널리 사용되고
있는 알고리즘이라고 생각이 됩니다.

로지스틱 회귀 알고리즘에 대한 상세한 설명이 나와있는 블로그를 하단에 소개해 드립니다.

아울러, 아래 소스코드를 하단에 공유해 드리니 테스트가 필요하신 분은 다운로드 후 사용하시면 됩니다. 

로지스틱 회귀(Logistic Regression) 설명 추천 블로그 :  릭 이동



Java 로지스틱 회귀(Logistic Regression) -  소스 코드


하단에 예제를 통해서 로지스틱 회귀 학습 및 결과 예제를 쉽게 사용해 볼 수 있습니다.

main.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package java_lr;
 
public class LogisticRegression {
    public int N;
    public int n_in;
    public int n_out;
    public double[][] W;
    public double[] b;
 
    public LogisticRegression(int N, int n_in, int n_out) {
        this.N = N;
        this.n_in = n_in;
        this.n_out = n_out;
 
        W = new double[n_out][n_in];
        b = new double[n_out];
    }
 
    public double[] train(double[] x, int[] y, double lr) {
        double[] p_y_given_x = new double[n_out];
        double[] dy = new double[n_out];
 
        for(int i=0; i<n_out; i++) {
            p_y_given_x[i] = 0;
            for(int j=0; j<n_in; j++) {
                p_y_given_x[i] += W[i][j] * x[j];
            }
            p_y_given_x[i] += b[i];
        }
        softmax(p_y_given_x);
 
        for(int i=0; i<n_out; i++) {
            dy[i] = y[i] - p_y_given_x[i];
 
            for(int j=0; j<n_in; j++) {
                W[i][j] += lr * dy[i] * x[j] / N;
            }
 
            b[i] += lr * dy[i] / N;
        }
 
        return dy;
    }
 
    public void softmax(double[] x) {
        double max = 0.0;
        double sum = 0.0;
 
        for(int i=0; i<n_out; i++if(max < x[i]) max = x[i];
 
        for(int i=0; i<n_out; i++) {
            x[i] = Math.exp(x[i] - max);
            sum += x[i];
        }
 
        for(int i=0; i<n_out; i++) x[i] /= sum;
    }
 
    public void predict(double[] x, double[] y) {
        for(int i=0; i<n_out; i++) {
            y[i] = 0.;
            for(int j=0; j<n_in; j++) {
                y[i] += W[i][j] * x[j];
            }
            y[i] += b[i];
        }
 
        softmax(y);
    }
 
    private static void test_lr() {
        double learning_rate = 0.1;
        int n_epochs = 500;
 
        int train_N = 6;
        int test_N = 2;
        int n_in = 6;
        int n_out = 2;
 
        double[][] train_X = {
                {1.1.1.0.0.0.},
                {1.0.1.0.0.0.},
                {1.1.1.0.0.0.},
                {0.0.1.1.1.0.},
                {0.0.1.1.0.0.},
                {0.0.1.1.1.0.}
        };
 
        int[][] train_Y = {
                {10},
                {10},
                {10},
                {01},
                {01},
                {01}
        };
 
        // construct
        LogisticRegression classifier = new LogisticRegression(train_N, n_in, n_out);
 
        // train
        for(int epoch=0; epoch<n_epochs; epoch++) {
            for(int i=0; i<train_N; i++) {
                classifier.train(train_X[i], train_Y[i], learning_rate);
            }
            //learning_rate *= 0.95;
        }
 
        // test data
        double[][] test_X = {
                {1.0.1.0.0.0.},
                {0.0.1.1.1.0.}
        };
 
        double[][] test_Y = new double[test_N][n_out];
 
 
        // test
        for(int i=0; i<test_N; i++) {
            classifier.predict(test_X[i], test_Y[i]);
            for(int j=0; j<n_out; j++) {
                System.out.print(test_Y[i][j] + " ");
            }
            System.out.println();
        }
    }
 
    public static void main(String[] args) {
        test_lr();
    }
}
cs

Java 로지스틱 회귀 Github 주소 및 레퍼런스 - 다운로드 



Java 로지스틱 회귀(Logistic Regression) - 실제 실행 화면 


아래 이미지로 실제 실행 화면을 확인하실 수 있습니다.



실제 실행 화면



마무리


이번 포스팅에서는 Java 기반 활용한 간단한 로지스틱 회귀(Logistic Regression) 소스코드를 공유해 
보았습니다.

범주형 변수를 예측하는 부분 및 사건의 발생 가능성을 예측하는 분야에 있어서 필드에서 상당히 많이
쓰이고 있는 알고리즘인 만큼 이론에 입각한 정확한 지식을 익히시는 것이 중요할 것이라 생각이 됩니다.

아울러, 파이썬 뿐만 아니라 Java, R 으로 다양한 기본 데이터 수집 후 관련 알고리즘으로 데이터 분석을 
실시해보시는 것도 실력향상에 상당한 도움이 될 것으로 생각됩니다.


다음 포스팅을 기약하며 이만 마무리하겠습니다.

소스코드 다운로드 :  java_logisticRg.zip