基于Hadoop实现Knn算法

网友投稿 264 2023-01-16


基于Hadoop实现Knn算法

Knn算法的核心思想是如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。Knn方法在类别决策时,只与极少量的相邻样本有关。由于Knn方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,Knn方法较其他方法更为合适。

Knn算法流程如下:

1. 计算当前测试数据与训练数据中的每条数据的距离

2. 圈定距离最近的K个训练对象,作为测试对象的近邻

3. 计算这K个训练对象中出现最多的那个类别,并将这个类别作为当前测试数据的类别

以上是Knn的大致流程。直接按照这个流程实现的MapReduce程序效率并不高,还可以在此基础上进行优化;这里只给出按照该流程实现的MR过程。

Mapper的设计:

由于测试数据相比训练数据要小得多,因此用Java API读取测试数据并放到内存中,即在setup中对测试数据进行初始化。每次调用map处理一条训练数据,计算它与每条测试数据的距离。Mapper的类型参数为<Object, Text, IntWritable, MyWritable>:map输出键类型为IntWritable,存放测试数据的下标;输出值类型为MyWritable,这是自定义值类型,其中存放的是距离以及参与比较的训练数据的类别。

public class KnnMapper extends Mapper {

Logger log = LoggerFactory.getLogger(KnnMapper.class);

private List testData;

@Override

protected void setup(Context context)

throws IOException, InterruptedException {

// TODO Auto-generated method stub

Configuration conf= context.getConfiguration();

conf.set("fs.defaultFS", "master:8020");

String testPath= conf.get("TestFilePath");

Path testDataPath= nehttp://w Path(testPath);

FileSystem fs = FileSystem.get(conf);

this.testData = readTestData(fs,testDataPath);

}

@Override

protected void map(Object key, Text value, Context context)

throws IOException, InterruptedException {

// TODO Auto-generated method stub

String[] line = value.toString().split(",");

float[] trainData = new float[line.length-1];

for(int i=0;i

trainData[i] = Float.valueOf(line[i]);

log.info("训练数据:"+line[i]+"类别:"+line[line.length-1]);

}

for(int i=0; i< this.testData.size();i++){

float[] testI = this.testData.get(i);

float distance = Outh(testI, trainData);

log.info("距离:"+distance);

context.write(new IntWritable(i), new MyWritable(distance, line[line.length-1]));

}

}

private List readTestData(FileSystem fs,Path Path) throws IOException {

//补充代码完整

FSDataInputStream data = fs.open(Path);

BufferedReader bf = new BufferedReader(new InputStreamReader(data));

String line = "";

List list = new ArrayList<>();

while ((line = bf.readLine()) != null) {

String[] items = line.split(",");

float[] item = new float[items.length];

for(int i=0;i

item[i] = Float.valueOf(items[i]);

}

list.add(item);

}

return list;

}

// 计算欧式距离

private static float Outh(float[] testData, float[] inData) {

float distance =0.0f;

for(int i=0;i

distance += (testData[i]-inData[i])*(testData[i]-inData[i]);

}

distance = (float)Math.sqrt(distance);

return distance;

}

}

自定义值类型MyWritable如下:

/**
 * Value type carrying a distance and a class label between map and reduce.
 *
 * <p>Binary form: the distance as a float, then the label via {@code writeUTF}.
 * Text form: {@code "<distance>,<label>"}.
 */
public class MyWritable implements Writable {

    private float distance;
    private String label;

    /** Creates an empty instance (distance 0, label null) to be filled by {@link #readFields}. */
    public MyWritable() {
        this(0.0f, null);
    }

    /**
     * @param distance distance between a test record and a training record
     * @param label    class label of the training record
     */
    public MyWritable(float distance, String label) {
        this.label = label;
        this.distance = distance;
    }

    public float getDistance() {
        return distance;
    }

    public void setDistance(float distance) {
        this.distance = distance;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    /** Serializes the distance first, then the label; {@link #readFields} reads in the same order. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeFloat(distance);
        out.writeUTF(label);
    }

    /** Restores the fields in the order written by {@link #write}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        distance = in.readFloat();
        label = in.readUTF();
    }

    @Override
    public String toString() {
        return new StringBuilder().append(distance).append(',').append(label).toString();
    }
}

在Reducer端,需要初始化参数K,即圈定距离最近的对象个数。在reduce中需要将距离按从小到大排序,选取前K条数据,再统计这K条数据中出现次数最多的类别,将该类别与测试数据的下标对应,以键值对的形式输出到HDFS上。

public class KnnReducer extends Reducer {

private int K;

@Override

protected void setup(Context context)

throws IOException, InterruptedException {

// TODO Auto-generated method stub

this.K = context.getConfiguration().getInt("K", 5);

}

@Override

/***

* key => 0

* values =>([1,lable1],[2,lable2],[3,label2],[2.5,lable2])

*/

protected void reduce(IntWritable key, Iterable values,

Context context) throws IOException, InterruptedException {

// TODO Auto-generated method stub

MyWritable[] mywrit = new MyWritable[K];

for(int i=0;i

mywrit[i] = new MyWritable(Float.MAX_VALUE, "-1");

}

// 找出距离最小的前k个

for (MyWritable m : values) {

float distance = m.getDistance();

String label = m.getLabel();

for(MyWritable m1: mywrit){

if (distance < m1.getDistance()){

m1.setDistance(distance);

m1.setLabel(label);

}

}

}

// 找出前k个中,出现次数最多的类别

String[] testClass = new String[K];

for(int i=0;i

testClass[i] = mywrit[i].getLabel();

}

String countMost = mostEle(testClass);

context.write(key, new Text(countMost));

}

public static String mostEle(String[] strArray) {

HashMap map = new HashMap<>();

for (int i = 0; i < strArray.length; i++) {

String str = strArray[i];

if (map.containsKey(str)) {

int tmp = map.get(str);

map.put(str, tmp+1);

}else{

map.put(str, 1);

}

}

// 得到hashmap中值最大的键,也就是出现次数最多的类别

Collection count = map.values();

int maxCount = Collections.max(count);

String maxString = "";

for(Map.Entry entry: map.entrySet()){

if (maxCount == entry.getValue()) {

maxString = entry.getKey();

}

}

return maxString;

}

}

最后输出结果如下:

trainData[i] = Float.valueOf(line[i]);

log.info("训练数据:"+line[i]+"类别:"+line[line.length-1]);

}

for(int i=0; i< this.testData.size();i++){

float[] testI = this.testData.get(i);

float distance = Outh(testI, trainData);

log.info("距离:"+distance);

context.write(new IntWritable(i), new MyWritable(distance, line[line.length-1]));

}

}

private List readTestData(FileSystem fs,Path Path) throws IOException {

//补充代码完整

FSDataInputStream data = fs.open(Path);

BufferedReader bf = new BufferedReader(new InputStreamReader(data));

String line = "";

List list = new ArrayList<>();

while ((line = bf.readLine()) != null) {

String[] items = line.split(",");

float[] item = new float[items.length];

for(int i=0;i

item[i] = Float.valueOf(items[i]);

}

list.add(item);

}

return list;

}

// 计算欧式距离

private static float Outh(float[] testData, float[] inData) {

float distance =0.0f;

for(int i=0;i

distance += (testData[i]-inData[i])*(testData[i]-inData[i]);

}

distance = (float)Math.sqrt(distance);

return distance;

}

}

自定义值类型MyWritable如下:

/**
 * Value type carrying a distance and a class label between map and reduce.
 *
 * <p>Binary form: the distance as a float, then the label via {@code writeUTF}.
 * Text form: {@code "<distance>,<label>"}.
 */
public class MyWritable implements Writable {

    private float distance;
    private String label;

    /** Creates an empty instance (distance 0, label null) to be filled by {@link #readFields}. */
    public MyWritable() {
        this(0.0f, null);
    }

    /**
     * @param distance distance between a test record and a training record
     * @param label    class label of the training record
     */
    public MyWritable(float distance, String label) {
        this.label = label;
        this.distance = distance;
    }

    public float getDistance() {
        return distance;
    }

    public void setDistance(float distance) {
        this.distance = distance;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    /** Serializes the distance first, then the label; {@link #readFields} reads in the same order. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeFloat(distance);
        out.writeUTF(label);
    }

    /** Restores the fields in the order written by {@link #write}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        distance = in.readFloat();
        label = in.readUTF();
    }

    @Override
    public String toString() {
        return new StringBuilder().append(distance).append(',').append(label).toString();
    }
}

在Reducer端,需要初始化参数K,即圈定距离最近的对象个数。在reduce中需要将距离按从小到大排序,选取前K条数据,再统计这K条数据中出现次数最多的类别,将该类别与测试数据的下标对应,以键值对的形式输出到HDFS上。

public class KnnReducer extends Reducer {

private int K;

@Override

protected void setup(Context context)

throws IOException, InterruptedException {

// TODO Auto-generated method stub

this.K = context.getConfiguration().getInt("K", 5);

}

@Override

/***

* key => 0

* values =>([1,lable1],[2,lable2],[3,label2],[2.5,lable2])

*/

protected void reduce(IntWritable key, Iterable values,

Context context) throws IOException, InterruptedException {

// TODO Auto-generated method stub

MyWritable[] mywrit = new MyWritable[K];

for(int i=0;i

mywrit[i] = new MyWritable(Float.MAX_VALUE, "-1");

}

// 找出距离最小的前k个

for (MyWritable m : values) {

float distance = m.getDistance();

String label = m.getLabel();

for(MyWritable m1: mywrit){

if (distance < m1.getDistance()){

m1.setDistance(distance);

m1.setLabel(label);

}

}

}

// 找出前k个中,出现次数最多的类别

String[] testClass = new String[K];

for(int i=0;i

testClass[i] = mywrit[i].getLabel();

}

String countMost = mostEle(testClass);

context.write(key, new Text(countMost));

}

public static String mostEle(String[] strArray) {

HashMap map = new HashMap<>();

for (int i = 0; i < strArray.length; i++) {

String str = strArray[i];

if (map.containsKey(str)) {

int tmp = map.get(str);

map.put(str, tmp+1);

}else{

map.put(str, 1);

}

}

// 得到hashmap中值最大的键,也就是出现次数最多的类别

Collection count = map.values();

int maxCount = Collections.max(count);

String maxString = "";

for(Map.Entry entry: map.entrySet()){

if (maxCount == entry.getValue()) {

maxString = entry.getKey();

}

}

return maxString;

}

}

最后输出结果如下:

item[i] = Float.valueOf(items[i]);

}

list.add(item);

}

return list;

}

// 计算欧式距离

private static float Outh(float[] testData, float[] inData) {

float distance =0.0f;

for(int i=0;i

distance += (testData[i]-inData[i])*(testData[i]-inData[i]);

}

distance = (float)Math.sqrt(distance);

return distance;

}

}

自定义值类型MyWritable如下:

/**
 * Value type carrying a distance and a class label between map and reduce.
 *
 * <p>Binary form: the distance as a float, then the label via {@code writeUTF}.
 * Text form: {@code "<distance>,<label>"}.
 */
public class MyWritable implements Writable {

    private float distance;
    private String label;

    /** Creates an empty instance (distance 0, label null) to be filled by {@link #readFields}. */
    public MyWritable() {
        this(0.0f, null);
    }

    /**
     * @param distance distance between a test record and a training record
     * @param label    class label of the training record
     */
    public MyWritable(float distance, String label) {
        this.label = label;
        this.distance = distance;
    }

    public float getDistance() {
        return distance;
    }

    public void setDistance(float distance) {
        this.distance = distance;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    /** Serializes the distance first, then the label; {@link #readFields} reads in the same order. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeFloat(distance);
        out.writeUTF(label);
    }

    /** Restores the fields in the order written by {@link #write}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        distance = in.readFloat();
        label = in.readUTF();
    }

    @Override
    public String toString() {
        return new StringBuilder().append(distance).append(',').append(label).toString();
    }
}

在Reducer端,需要初始化参数K,即圈定距离最近的对象个数。在reduce中需要将距离按从小到大排序,选取前K条数据,再统计这K条数据中出现次数最多的类别,将该类别与测试数据的下标对应,以键值对的形式输出到HDFS上。

public class KnnReducer extends Reducer {

private int K;

@Override

protected void setup(Context context)

throws IOException, InterruptedException {

// TODO Auto-generated method stub

this.K = context.getConfiguration().getInt("K", 5);

}

@Override

/***

* key => 0

* values =>([1,lable1],[2,lable2],[3,label2],[2.5,lable2])

*/

protected void reduce(IntWritable key, Iterable values,

Context context) throws IOException, InterruptedException {

// TODO Auto-generated method stub

MyWritable[] mywrit = new MyWritable[K];

for(int i=0;i

mywrit[i] = new MyWritable(Float.MAX_VALUE, "-1");

}

// 找出距离最小的前k个

for (MyWritable m : values) {

float distance = m.getDistance();

String label = m.getLabel();

for(MyWritable m1: mywrit){

if (distance < m1.getDistance()){

m1.setDistance(distance);

m1.setLabel(label);

}

}

}

// 找出前k个中,出现次数最多的类别

String[] testClass = new String[K];

for(int i=0;i

testClass[i] = mywrit[i].getLabel();

}

String countMost = mostEle(testClass);

context.write(key, new Text(countMost));

}

public static String mostEle(String[] strArray) {

HashMap map = new HashMap<>();

for (int i = 0; i < strArray.length; i++) {

String str = strArray[i];

if (map.containsKey(str)) {

int tmp = map.get(str);

map.put(str, tmp+1);

}else{

map.put(str, 1);

}

}

// 得到hashmap中值最大的键,也就是出现次数最多的类别

Collection count = map.values();

int maxCount = Collections.max(count);

String maxString = "";

for(Map.Entry entry: map.entrySet()){

if (maxCount == entry.getValue()) {

maxString = entry.getKey();

}

}

return maxString;

}

}

最后输出结果如下:

distance += (testData[i]-inData[i])*(testData[i]-inData[i]);

}

distance = (float)Math.sqrt(distance);

return distance;

}

}

自定义值类型MyWritable如下:

/**
 * Value type carrying a distance and a class label between map and reduce.
 *
 * <p>Binary form: the distance as a float, then the label via {@code writeUTF}.
 * Text form: {@code "<distance>,<label>"}.
 */
public class MyWritable implements Writable {

    private float distance;
    private String label;

    /** Creates an empty instance (distance 0, label null) to be filled by {@link #readFields}. */
    public MyWritable() {
        this(0.0f, null);
    }

    /**
     * @param distance distance between a test record and a training record
     * @param label    class label of the training record
     */
    public MyWritable(float distance, String label) {
        this.label = label;
        this.distance = distance;
    }

    public float getDistance() {
        return distance;
    }

    public void setDistance(float distance) {
        this.distance = distance;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    /** Serializes the distance first, then the label; {@link #readFields} reads in the same order. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeFloat(distance);
        out.writeUTF(label);
    }

    /** Restores the fields in the order written by {@link #write}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        distance = in.readFloat();
        label = in.readUTF();
    }

    @Override
    public String toString() {
        return new StringBuilder().append(distance).append(',').append(label).toString();
    }
}

在Reducer端,需要初始化参数K,即圈定距离最近的对象个数。在reduce中需要将距离按从小到大排序,选取前K条数据,再统计这K条数据中出现次数最多的类别,将该类别与测试数据的下标对应,以键值对的形式输出到HDFS上。

public class KnnReducer extends Reducer {

private int K;

@Override

protected void setup(Context context)

throws IOException, InterruptedException {

// TODO Auto-generated method stub

this.K = context.getConfiguration().getInt("K", 5);

}

@Override

/***

* key => 0

* values =>([1,lable1],[2,lable2],[3,label2],[2.5,lable2])

*/

protected void reduce(IntWritable key, Iterable values,

Context context) throws IOException, InterruptedException {

// TODO Auto-generated method stub

MyWritable[] mywrit = new MyWritable[K];

for(int i=0;i

mywrit[i] = new MyWritable(Float.MAX_VALUE, "-1");

}

// 找出距离最小的前k个

for (MyWritable m : values) {

float distance = m.getDistance();

String label = m.getLabel();

for(MyWritable m1: mywrit){

if (distance < m1.getDistance()){

m1.setDistance(distance);

m1.setLabel(label);

}

}

}

// 找出前k个中,出现次数最多的类别

String[] testClass = new String[K];

for(int i=0;i

testClass[i] = mywrit[i].getLabel();

}

String countMost = mostEle(testClass);

context.write(key, new Text(countMost));

}

public static String mostEle(String[] strArray) {

HashMap map = new HashMap<>();

for (int i = 0; i < strArray.length; i++) {

String str = strArray[i];

if (map.containsKey(str)) {

int tmp = map.get(str);

map.put(str, tmp+1);

}else{

map.put(str, 1);

}

}

// 得到hashmap中值最大的键,也就是出现次数最多的类别

Collection count = map.values();

int maxCount = Collections.max(count);

String maxString = "";

for(Map.Entry entry: map.entrySet()){

if (maxCount == entry.getValue()) {

maxString = entry.getKey();

}

}

return maxString;

}

}

最后输出结果如下:

mywrit[i] = new MyWritable(Float.MAX_VALUE, "-1");

}

// 找出距离最小的前k个

for (MyWritable m : values) {

float distance = m.getDistance();

String label = m.getLabel();

for(MyWritable m1: mywrit){

if (distance < m1.getDistance()){

m1.setDistance(distance);

m1.setLabel(label);

}

}

}

// 找出前k个中,出现次数最多的类别

String[] testClass = new String[K];

for(int i=0;i

testClass[i] = mywrit[i].getLabel();

}

String countMost = mostEle(testClass);

context.write(key, new Text(countMost));

}

/**
 * Returns the most frequent element of {@code strArray}.
 *
 * <p>Ties go to the element that occurs first in the array. If the caller orders labels by
 * ascending distance, a tie is therefore won by the closest tied neighbour.
 *
 * @param strArray labels to vote on
 * @return the majority label, or {@code ""} if the array is empty
 */
public static String mostEle(String[] strArray) {
    HashMap<String, Integer> map = new HashMap<>();
    for (String str : strArray) {
        map.merge(str, 1, Integer::sum);
    }
    String maxString = "";
    int maxCount = 0;
    // Scan in array order with strict '>' so the first label reaching the max wins.
    for (String str : strArray) {
        int count = map.get(str);
        if (count > maxCount) {
            maxCount = count;
            maxString = str;
        }
    }
    return maxString;
}

}

最后输出结果如下:

testClass[i] = mywrit[i].getLabel();

}

String countMost = mostEle(testClass);

context.write(key, new Text(countMost));

}

/**
 * Returns the most frequent element of {@code strArray}.
 *
 * <p>Ties go to the element that occurs first in the array. If the caller orders labels by
 * ascending distance, a tie is therefore won by the closest tied neighbour.
 *
 * @param strArray labels to vote on
 * @return the majority label, or {@code ""} if the array is empty
 */
public static String mostEle(String[] strArray) {
    HashMap<String, Integer> map = new HashMap<>();
    for (String str : strArray) {
        map.merge(str, 1, Integer::sum);
    }
    String maxString = "";
    int maxCount = 0;
    // Scan in array order with strict '>' so the first label reaching the max wins.
    for (String str : strArray) {
        int count = map.get(str);
        if (count > maxCount) {
            maxCount = count;
            maxString = str;
        }
    }
    return maxString;
}

}

最后输出结果如下:


版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:关于网址post测试json的信息
下一篇:研发管理平台公司简介(研发管理平台公司简介范文)
相关文章

 发表评论

暂时没有评论,来抢沙发吧~