Java - 로지스틱 회귀(Logistic Regression)
프로젝트 진행 중에 Java를 활용한 로지스틱 회귀 관련 알고리즘을 발견해서 사용해 본 후 해당 내용을
공유해봅니다.
로지스틱 회귀는 주로 반응변수가 주로 이진형 값에서 주로 사용되는 회귀 분석 방법이라고 정의할 수
있습니다.
해석의 편의성과 샘플링 데이터의 계수 추정치를 편하게 계산할 수 있는 관계로 널리 사용되고
있는 알고리즘이라고 생각이 됩니다.
로지스틱 회귀 알고리즘에 대한 상세한 설명이 나와있는 블로그를 하단에 소개해 드립니다.
아울러, 아래 소스코드를 하단에 공유해 드리니 테스트가 필요하신 분은 다운로드 후 사용하시면 됩니다.
로지스틱 회귀(Logistic Regression) 설명 추천 블로그 : 클릭 이동
로지스틱 회귀(Logistic Regression) 설명 추천 블로그 : 클릭 이동
Java 로지스틱 회귀(Logistic Regression) - 소스 코드
하단에 예제를 통해서 로지스틱 회귀 학습 및 결과 예제를 쉽게 사용해 볼 수 있습니다.
main.java
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 package java_lr; public class LogisticRegression { public int N; public int n_in; public int n_out; public double[][] W; public double[] b; public LogisticRegression(int N, int n_in, int n_out) { this.N = N; this.n_in = n_in; this.n_out = n_out; W = new double[n_out][n_in]; b = new double[n_out]; } public double[] train(double[] x, int[] y, double lr) { double[] p_y_given_x = new double[n_out]; double[] dy = new double[n_out]; for(int i=0; i<n_out; i++) { p_y_given_x[i] = 0; for(int j=0; j<n_in; j++) { p_y_given_x[i] += W[i][j] * x[j]; } p_y_given_x[i] += b[i]; } softmax(p_y_given_x); for(int i=0; i<n_out; i++) { dy[i] = y[i] - p_y_given_x[i]; for(int j=0; j<n_in; j++) { W[i][j] += lr * dy[i] * x[j] / N; } b[i] += lr * dy[i] / N; } return dy; } public void softmax(double[] x) { double max = 0.0; double sum = 0.0; for(int i=0; i<n_out; i++) if(max < x[i]) max = x[i]; for(int i=0; i<n_out; i++) { x[i] = Math.exp(x[i] - max); sum += x[i]; } for(int i=0; i<n_out; i++) x[i] /= sum; } public void predict(double[] x, double[] y) { for(int i=0; i<n_out; i++) { y[i] = 0.; for(int j=0; j<n_in; j++) { y[i] += W[i][j] * x[j]; } y[i] += b[i]; } softmax(y); } private static void test_lr() { double learning_rate = 0.1; int n_epochs = 500; int train_N = 6; int test_N = 2; int n_in = 6; int n_out = 2; double[][] train_X = { {1., 1., 1., 0., 0., 0.}, {1., 0., 1., 0., 0., 0.}, {1., 1., 1., 0., 0., 0.}, {0., 0., 1., 1., 1., 0.}, {0., 0., 1., 1., 0., 0.}, {0., 0., 1., 1., 1., 0.} }; int[][] train_Y = { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1} }; // construct LogisticRegression classifier = new LogisticRegression(train_N, n_in, n_out); // train for(int epoch=0; epoch<n_epochs; epoch++) { for(int i=0; i<train_N; i++) { classifier.train(train_X[i], train_Y[i], learning_rate); } //learning_rate *= 0.95; } // test data double[][] test_X = { {1., 0., 1., 0., 0., 0.}, {0., 0., 1., 1., 1., 0.} }; double[][] test_Y = new double[test_N][n_out]; // test for(int i=0; i<test_N; i++) { classifier.predict(test_X[i], test_Y[i]); for(int j=0; j<n_out; j++) { System.out.print(test_Y[i][j] + " "); } System.out.println(); } } public static void main(String[] args) { test_lr(); }} cs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | package java_lr; public class LogisticRegression { public int N; public int n_in; public int n_out; public double[][] W; public double[] b; public LogisticRegression(int N, int n_in, int n_out) { this.N = N; this.n_in = n_in; this.n_out = n_out; W = new double[n_out][n_in]; b = new double[n_out]; } public double[] train(double[] x, int[] y, double lr) { double[] p_y_given_x = new double[n_out]; double[] dy = new double[n_out]; for(int i=0; i<n_out; i++) { p_y_given_x[i] = 0; for(int j=0; j<n_in; j++) { p_y_given_x[i] += W[i][j] * x[j]; } p_y_given_x[i] += b[i]; } softmax(p_y_given_x); for(int i=0; i<n_out; i++) { dy[i] = y[i] - p_y_given_x[i]; for(int j=0; j<n_in; j++) { W[i][j] += lr * dy[i] * x[j] / N; } b[i] += lr * dy[i] / N; } return dy; } public void softmax(double[] x) { double max = 0.0; double sum = 0.0; for(int i=0; i<n_out; i++) if(max < x[i]) max = x[i]; for(int i=0; i<n_out; i++) { x[i] = Math.exp(x[i] - max); sum += x[i]; } for(int i=0; i<n_out; i++) x[i] /= sum; } public void predict(double[] x, double[] y) { for(int i=0; i<n_out; i++) { y[i] = 0.; for(int j=0; j<n_in; j++) { y[i] += W[i][j] * x[j]; } y[i] += b[i]; } softmax(y); } private static void test_lr() { double learning_rate = 0.1; int n_epochs = 500; int train_N = 6; int test_N = 2; int n_in = 6; int n_out = 2; double[][] train_X = { {1., 1., 1., 0., 0., 0.}, {1., 0., 1., 0., 0., 0.}, {1., 1., 1., 0., 0., 0.}, {0., 0., 1., 1., 1., 0.}, {0., 0., 1., 1., 0., 0.}, {0., 0., 1., 1., 1., 0.} }; int[][] train_Y = { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1} }; // construct LogisticRegression classifier = new LogisticRegression(train_N, n_in, n_out); // train for(int epoch=0; epoch<n_epochs; epoch++) { for(int i=0; i<train_N; i++) { classifier.train(train_X[i], train_Y[i], learning_rate); } //learning_rate *= 0.95; } // test data double[][] test_X = { {1., 0., 1., 0., 0., 0.}, {0., 0., 1., 1., 1., 0.} }; double[][] test_Y = new double[test_N][n_out]; // test for(int i=0; i<test_N; i++) { classifier.predict(test_X[i], test_Y[i]); for(int j=0; j<n_out; j++) { System.out.print(test_Y[i][j] + " "); } System.out.println(); } } public static void main(String[] args) { test_lr(); } } | cs |
Java 로지스틱 회귀 Github 주소 및 레퍼런스 - 다운로드
Java 로지스틱 회귀 Github 주소 및 레퍼런스 - 다운로드
Java 로지스틱 회귀(Logistic Regression) - 실제 실행 화면
아래 이미지로 실제 실행 화면을 확인하실 수 있습니다.
실제 실행 화면
마무리
이번 포스팅에서는 Java 기반 활용한 간단한 로지스틱 회귀(Logistic Regression) 소스코드를 공유해
보았습니다.
범주형 변수를 예측하는 부분 및 사건의 발생 가능성을 예측하는 분야에 있어서 필드에서 상당히 많이
쓰이고 있는 알고리즘인 만큼 이론에 입각한 정확한 지식을 익히시는 것이 중요할 것이라 생각이 됩니다.
아울러, 파이썬 뿐만 아니라 Java, R 으로 다양한 기본 데이터 수집 후 관련 알고리즘으로 데이터 분석을
실시해보시는 것도 실력향상에 상당한 도움이 될 것으로 생각됩니다.
다음 포스팅을 기약하며 이만 마무리하겠습니다.
소스코드 다운로드 : java_logisticRg.zip
'언어 > Java' 카테고리의 다른 글
Java(자바) 알고리즘 - 나이브베이지안(Naive Bayesian) 오픈소스 및 예제 (0) | 2018.10.15 |
---|---|
Java(자바) 디자인패턴 - 팩토리(Factory Method) 패턴 설명 및 예제소스 (7) | 2018.05.21 |
Java(자바) 디자인패턴 - 템플릿 메소드(Template Method) 패턴 설명 및 예제소스 (1) | 2018.05.18 |
Java(자바) 디자인패턴 - 어댑터(Adapter) 패턴 설명 및 예제소스 (0) | 2018.05.15 |
Java(자바) 디자인패턴 - 전략(Strategy) 패턴 설명 및 예제소스 (3) | 2018.04.18 |