基于Java实现的一层简单人工神经网络算法示例

网友投稿 588 2023-03-07


基于Java实现的一层简单人工神经网络算法示例

本文实例讲述了基于java实现的一层简单人工神经网络算法。分享给大家供大家参考,具体如下:

先来看看笔者绘制的算法图:

2、数据类

import java.util.Arrays;

public class Data {

double[] vector;

int dimention;

int type;

public double[] getVector() {

return vector;

}

public void http://setVector(double[] vector) {

this.vector = vector;

}

public int getDimention() {

return dimention;

}

public void setDimention(int dimention) {

this.dimention = dimention;

}

public int getType() {

return type;

}

public void setType(int type) {

this.type = type;

}

public Data(double[] vector, int dimention, int type) {

super();

this.vector = vector;

this.dimention = dimention;

this.type = type;

}

public Data() {

}

@Override

public String toString() {

return "Data [vector=" + Arrays.toString(vector) + ", dimention=" + dimention + ", type=" + type + "]";

}

}

3、简单人工神经网络

package cn.edu.hbut.chenjie;

import java.util.ArrayList;

import java.util.List;

import java.util.Random;

import org.jfree.chart.ChartFactory;

import org.jfree.chart.ChartFrame;

import org.jfree.chart.JFreeChart;

import org.jfree.data.xy.DefaultXYDataset;

import org.jfree.ui.RefineryUtilities;

public class ANN2 {

private double eta;//学习率

private int n_iter;//权重向量w[]训练次数

private List exercise;//训练数据集

private double w0 = 0;//阈值

private double x0 = 1;//固定值

private double[] weights;//权重向量,其长度为训练数据维度+1,在本例中数据为2维,故长度为3

private int testSum = 0;//测试数据总数

private int error = 0;//错误次数

DefaultXYDataset xydataset = new DefaultXYDataset();

/**

* 向图表中增加同类型的数据

* @param type 类型

* @param a 所有数据的第一个分量

* @param b 所有数据的第二个分量

*/

public void add(String type,double[] a,double[] b)

{

double[][] data = new double[2][a.length];

for(int i=0;i

{

data[0][i] = a[i];

data[1][i] = b[i];

}

xydataset.addSeries(type, data);

}

/**

* 画图

*/

public void draw()

{

JFreeChart jfreechart = ChartFactory.createScatterPlot("exercise", "x1", "x2", xydataset);

ChartFrame frame = new ChartFrame("训练数据", jfreechart);

frame.pack();

RefineryUtilities.centerFrameOnScreen(frame);

frame.setVisible(true);

}

public static void main(String[] args)

{

ANN2 ann2 = new ANN2(0.001,100);//构造人工神经网络

List exercise = new ArrayList();//构造训练集

//人工模拟1000条训练数据 ,分界线为x2=x1+0.5

for(int i=0;i<1000000;i++)

{

Random rd = new Random();

double x1 = rd.nextDouble();//随机产生一个分量

double x2 = rd.nextDouble();//随机产生另一个分量

double[] da = {x1,x2};//产生数据向量

Data d = new Data(da, 2, x2 > x1+0.5 ? 1 : -1);//构造数据

exercise.add(d);//将训练数据加入训练集

}

int sum1 = 0;//记录类型1的训练记录数

int sum2 = 0;//记录类型-1的训练记录数

for(int i = 0; i < exercise.size(); i++)

{

if(exercise.get(i).getType()==1)

sum1++;

else if(exercise.get(i).getType()==-1)

sum2++;

}

double[] x1 = new double[sum1];

double[] y1 = new double[sum1];

double[] x2 = new double[sum2];

double[] y2 = new double[sum2];

int index1 = 0;

int index2 = 0;

for(int i = 0; i < exercise.size(); i++)

{

if(exercise.get(i).getType()==1)

{

x1[index1] = exercise.get(i).vector[0];

y1[index1++] = exercise.get(i).vector[1];

}

else if(exercise.get(i).getType()==-1)

{

x2[index2] = exercise.get(i).vector[0];

y2[index2++] = exercise.get(i).vector[1];

}

}

ann2.add("1", x1, y1);

ann2.add("-1", x2, y2);

ann2.draw();

ann2.input(exercise);//将训练集输入人工神经网络

ann2.fit();//训练

ann2.showWeigths();//显示权重向量

//人工生成一千条测试数据

for(int i=0;i<10000;i++)

{

Random rd = new Random();

double x1_ = rd.nextDouble();

double x2_ = rd.nextDouble();

double[] da = {x1_,x2_};

Dathttp://a test = new Data(da, 2, x2_ > x1_+0.5 ? 1 : -1);

ann2.predict(test);//测试

}

System.out.println("总共测试" + ann2.testSum + "条数据,有" + ann2.error + "条错误,错误率:" + ann2.error * 1.0 /ann2.testSum * 100 + "%");

}

/**

*

* @param eta 学习率

* @param n_iter 权重分量学习次数

*/

public ANN2(double eta, int n_iter) {

this.eta = eta;

this.n_iter = n_iter;

}

/**

* 输入训练集到人工神经网络

* @param exercise

*/

private void input(List exercise) {

this.exercLrEFbseKise = exercise;//保存训练集

weights = new double[exercise.get(0).dimention + 1];//初始化权重向量,其长度为训练数据维度+1

weights[0] = w0;//权重向量第一个分量为w0

for(int i = 1; i < weights.length; i++)

weights[i] = 0;//其余分量初始化为0

}

private void fit() {

for(int i = 0; i < n_iter; i++)//权重分量调整n_iter次

{

for(int j = 0; j < exercise.size(); j++)//对于训练集中的每条数据进行训练

{

int real_result = exercise.get(j).type;//y

int calculate_result = CalculateResult(exercise.get(j));//y'

double delta0 = eta * (real_result - calculate_result);//计算阈值更新

w0 += delta0;//阈值更新

weights[0] = w0;//更新w[0]

for(int k = 0; k < exercise.get(j).getDimention(); k++)//更新权重向量其它分量

{

double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k];

//w=*(y-y')*X

weights[k+1] += delta;

//w=w+w

}

}

}

}

private int CalculateResult(Data data) {

double z = w0 * x0;

for(int i = 0; i < data.dimention; i++)

z += data.vector[i] * weights[i+1];

//z=w0x0+w1x1+...+WmXm

//激活函数

if(z>=0)

return 1;

else

return -1;

}

private void showWeigths()

{

for(double w : weights)

System.out.println(w);

}

private void predict(Data data) {

int type = CalculateResult(data);

if(type == data.getType())

{

//System.out.println("预测正确");

}

else

{

//System.out.println("预测错误");

error ++;

}

testSum ++;

}

}

运行结果:

-0.22000000000000017

-0.4416843982815453

0.442444202054685

总共测试10000条数据,有17条错误,错误率:0.16999999999999998%

更多关于java算法相关内容感兴趣的读者可查看本站专题:《Java数据结构与算法教程》、《Java操作DOM节点技巧总结》、《Java文件与目录操作技巧汇总》和《Java缓存操作技巧汇总》

希望本文所述对大家java程序设计有所帮助。

{

data[0][i] = a[i];

data[1][i] = b[i];

}

xydataset.addSeries(type, data);

}

/**

* 画图

*/

public void draw()

{

JFreeChart jfreechart = ChartFactory.createScatterPlot("exercise", "x1", "x2", xydataset);

ChartFrame frame = new ChartFrame("训练数据", jfreechart);

frame.pack();

RefineryUtilities.centerFrameOnScreen(frame);

frame.setVisible(true);

}

public static void main(String[] args)

{

ANN2 ann2 = new ANN2(0.001,100);//构造人工神经网络

List exercise = new ArrayList();//构造训练集

//人工模拟1000条训练数据 ,分界线为x2=x1+0.5

for(int i=0;i<1000000;i++)

{

Random rd = new Random();

double x1 = rd.nextDouble();//随机产生一个分量

double x2 = rd.nextDouble();//随机产生另一个分量

double[] da = {x1,x2};//产生数据向量

Data d = new Data(da, 2, x2 > x1+0.5 ? 1 : -1);//构造数据

exercise.add(d);//将训练数据加入训练集

}

int sum1 = 0;//记录类型1的训练记录数

int sum2 = 0;//记录类型-1的训练记录数

for(int i = 0; i < exercise.size(); i++)

{

if(exercise.get(i).getType()==1)

sum1++;

else if(exercise.get(i).getType()==-1)

sum2++;

}

double[] x1 = new double[sum1];

double[] y1 = new double[sum1];

double[] x2 = new double[sum2];

double[] y2 = new double[sum2];

int index1 = 0;

int index2 = 0;

for(int i = 0; i < exercise.size(); i++)

{

if(exercise.get(i).getType()==1)

{

x1[index1] = exercise.get(i).vector[0];

y1[index1++] = exercise.get(i).vector[1];

}

else if(exercise.get(i).getType()==-1)

{

x2[index2] = exercise.get(i).vector[0];

y2[index2++] = exercise.get(i).vector[1];

}

}

ann2.add("1", x1, y1);

ann2.add("-1", x2, y2);

ann2.draw();

ann2.input(exercise);//将训练集输入人工神经网络

ann2.fit();//训练

ann2.showWeigths();//显示权重向量

//人工生成一千条测试数据

for(int i=0;i<10000;i++)

{

Random rd = new Random();

double x1_ = rd.nextDouble();

double x2_ = rd.nextDouble();

double[] da = {x1_,x2_};

Dathttp://a test = new Data(da, 2, x2_ > x1_+0.5 ? 1 : -1);

ann2.predict(test);//测试

}

System.out.println("总共测试" + ann2.testSum + "条数据,有" + ann2.error + "条错误,错误率:" + ann2.error * 1.0 /ann2.testSum * 100 + "%");

}

/**

*

* @param eta 学习率

* @param n_iter 权重分量学习次数

*/

public ANN2(double eta, int n_iter) {

this.eta = eta;

this.n_iter = n_iter;

}

/**

* 输入训练集到人工神经网络

* @param exercise

*/

private void input(List exercise) {

this.exercLrEFbseKise = exercise;//保存训练集

weights = new double[exercise.get(0).dimention + 1];//初始化权重向量,其长度为训练数据维度+1

weights[0] = w0;//权重向量第一个分量为w0

for(int i = 1; i < weights.length; i++)

weights[i] = 0;//其余分量初始化为0

}

private void fit() {

for(int i = 0; i < n_iter; i++)//权重分量调整n_iter次

{

for(int j = 0; j < exercise.size(); j++)//对于训练集中的每条数据进行训练

{

int real_result = exercise.get(j).type;//y

int calculate_result = CalculateResult(exercise.get(j));//y'

double delta0 = eta * (real_result - calculate_result);//计算阈值更新

w0 += delta0;//阈值更新

weights[0] = w0;//更新w[0]

for(int k = 0; k < exercise.get(j).getDimention(); k++)//更新权重向量其它分量

{

double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k];

//w=*(y-y')*X

weights[k+1] += delta;

//w=w+w

}

}

}

}

private int CalculateResult(Data data) {

double z = w0 * x0;

for(int i = 0; i < data.dimention; i++)

z += data.vector[i] * weights[i+1];

//z=w0x0+w1x1+...+WmXm

//激活函数

if(z>=0)

return 1;

else

return -1;

}

private void showWeigths()

{

for(double w : weights)

System.out.println(w);

}

private void predict(Data data) {

int type = CalculateResult(data);

if(type == data.getType())

{

//System.out.println("预测正确");

}

else

{

//System.out.println("预测错误");

error ++;

}

testSum ++;

}

}

运行结果:

-0.22000000000000017

-0.4416843982815453

0.442444202054685

总共测试10000条数据,有17条错误,错误率:0.16999999999999998%

更多关于java算法相关内容感兴趣的读者可查看本站专题:《Java数据结构与算法教程》、《Java操作DOM节点技巧总结》、《Java文件与目录操作技巧汇总》和《Java缓存操作技巧汇总》

希望本文所述对大家java程序设计有所帮助。


版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:浅谈Java BitSet使用场景和代码示例
下一篇:JDBC连接mysql乱码异常问题处理总结
相关文章

 发表评论

暂时没有评论,来抢沙发吧~