1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > opencv3/C++ 机器学习-EM算法/Expectation Maximization

opencv3/C++ 机器学习-EM算法/Expectation Maximization

时间:2022-11-01 10:53:28

相关推荐

opencv3/C++ 机器学习-EM算法/Expectation Maximization

EM算法/Expectation Maximization

EM算法包含两步:E,求期望(Expectation),利用概率模型参数的现有估计值,计算隐藏变量的期望;M,求极大(Maximization),利用E 步上求得的隐藏变量的期望,对参数模型进行最大似然估计。所得参数估计值用于下个E步的计算,重复至收敛。

期望最大化/EM算法以具有一定数量混合物的高斯混合分布的形式估计多变量概率密度函数的参数。

考虑从高斯混合模型画出的d维欧几里得空间中的N个特征向量{x1,x2,...,xNx1,x2,...,xN}的集合:

p(x;ak;Sk;πk)=∑mk=1πkpk(x),πk⩾0,∑mk=1πk=1p(x;ak;Sk;πk)=∑k=1mπkpk(x),πk⩾0,∑k=1mπk=1

pk(x)=φ(x;ak,Sk)=1(2π)d/2|Sk|1/2exp{−12(x−ak)TS−1k(x−ak)}pk(x)=φ(x;ak,Sk)=1(2π)d/2|Sk|1/2exp{−12(x−ak)TSk−1(x−ak)}

其中m是高斯混合模型的数量,pkpk是具有均值akak和协方差矩阵SkSk的正态分布密度,πkπk是第k个高斯混合模型的权重。 给定高斯混合模型个数M和样本xi,i=1...Nxi,i=1...N,算法找到所有高斯混合模型参数的最大似然估计(MLE),即ak,Skak,Sk和πkπk:

L(x,θ)=logp(x,θ)=∑ni=1log(∑mk=1πkpk(x))→maxθ∈Θ,L(x,θ)=logp(x,θ)=∑i=1nlog(∑k=1mπkpk(x))→maxθ∈Θ,

Θ=Θ=

{(ak,Sk,πk):ak∈Rd,Sk=STk>0,Sk∈Rd×d,πk≥0,∑mk=1πk=1}{(ak,Sk,πk):ak∈Rd,Sk=SkT>0,Sk∈Rd×d,πk≥0,∑k=1mπk=1}

EM算法是一个迭代过程。 每次迭代包括两个步骤。 在第一步E步即预期步骤中,使用当前可用的混合参数估计值,可以找出样本i属于混合模型k的概率pikpik(在下面的公式中表示为αikαik):

aki=πkφ(x;ak,Sk)∑mj=1πjφ(x;aj,Sj)aki=πkφ(x;ak,Sk)∑j=1mπjφ(x;aj,Sj)

在第二步M步即最大化步骤中,使用计算出的概率对高斯混合模型的参数估计值进行细化:

πk=1N∑Ni=1akiπk=1N∑i=1Naki;

πk=∑Ni=1akixi∑Ni=1akiπk=∑i=1Nakixi∑i=1Naki ;

Sk=∑Ni=1aki(xi−ak)(xi−ak)T∑Ni=1akiSk=∑i=1Naki(xi−ak)(xi−ak)T∑i=1Naki

或者,当提供pikpik的初始值时,该算法可以从M步开始。 当pikpik未知时的另一种选择是使用更简单的聚类算法对输入采样进行预先聚类,从而获得初始的pikpik(通常用k-means算法实现)。

EM算法的一个主要问题是需要估计大量参数。 大多数参数存在于协方差矩阵中,这些矩阵大小为d×d,其中d是特征空间维度。 但在许多实际问题中,协方差矩阵接近于对角线或者甚至接近μk∗Iμk∗I,其中II是单位矩阵,μk" role="presentation">μk是混合相关的“比例”参数。 因此,一个健壮的计算方案是对协方差矩阵加较强的约束,然后用估计的参数作为较少约束优化问题的输入(通常对角协方差矩阵已经足够了)。

OpenCV EM类

相关函数

virtual void setClustersNumber(int val);

高斯混合模型中混合成分的数量。默认值是EM :: DEFAULT_NCLUSTERS = 5。

virtual void setCovarianceMatrixType(int val);

协方差矩阵的类型。协方差矩阵的约束定义了其类型。

COV_MAT_SPHERICAL= 0:缩放的单位矩阵 μk∗Iμk∗I。 对每个矩阵估计唯一的参数μkμk。用于约束条件相关时或作为优化的第一步(例如数据用PCA预处理时)。

COV_MAT_DIAGONAL= 1:具有正对角元素的对角矩阵。 每个矩阵d个自由参数。 (常选项,估算结果良好)

COV_MAT_GENERIC= 2:对称正定矩阵。 每个矩阵中的自由参数大约d2/2d2/2个。 不建议使用此选项,除非对参数或大量训练样本有相当准确的初始估计。

virtual bool trainEM(InputArray samples,OutputArray logLikelihoods=noArray(),OutputArray labels=noArray(),OutputArray probs=noArray()) ;

估计样本集中的高斯混合模型参数。

这种变化开始于Expectation步。模型参数的初始值通过k-means估计。与许多ML模型不同,EM是无监督学习算法,因此训练时不用输入类标签。通过样本数据计算高斯混合参数的最大似然估计,将结构中的所有参数进行存储:pi,kpi,k存概率, akak 存均值,SkSk存covs [k],πkπk存权重,并且可选地为每个样本计算输出“类别标签”:labelsi=arg maxk(pi,k),i=1..Nlabelsi=arg maxk(pi,k),i=1..N(每个样本的最可能的模型分量的索引)。训练好的模型可以用于预测。

samples ::样本。单通道矩阵,每一行为一个样本。若矩阵不是CV_64F类型,则将被转换为此类型的内部矩阵。

logLikelihoods ::可选输出矩阵,包含每个样本的似然对数值。大小nsamples×1nsamples×1,类型CV_64FC1。

labels ::每个样本的输出“类别标签”:labelsi=arg maxk(pi,k),i=1..Nlabelsi=arg maxk(pi,k),i=1..N每个样本最可能的高斯混合模型分量)。大小nsamples×1nsamples×1 ,类型CV_32SC1。

probs ::可选输出矩阵,包含每个给定样本的各个高斯混合模型分量的后验概率。大小 nsamples×nclustersnsamples×nclusters ,类型CV_64FC1。

应用示例

图像分割

使用EM算法对图像进行分割。

#include <opencv2/opencv.hpp>#include <iostream>using namespace std;using namespace cv;using namespace cv::ml;int main(){Vec3b colors[] ={Vec3b(0, 0, 255), Vec3b(0, 255, 0), Vec3b(255, 100, 100), Vec3b(255, 0, 255)};Mat data, labels, src, dst;src = imread("E:/image/image/red.jpg", 1);resize(src, src, Size(src.cols/1.5,src.rows/1.5));if(src.empty()){printf("can not load image \n");return -1;}src.copyTo(dst);for (int i = 0; i < src.rows; i++)for (int j = 0; j < src.cols; j++){Vec3b point = src.at<Vec3b>(i, j);Mat tmp = (Mat_<float>(1, 3) << point[0], point[1], point[2]);data.push_back(tmp);}Ptr<EM> model = EM::create();model->setClustersNumber(4); //类个数model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 300, 0.1));model->trainEM(data, noArray(), labels, noArray());int n = 0;//显示结果,不同的类别用不同的颜色for (int i = 0; i < dst.rows; i++)for (int j = 0; j < dst.cols; j++){int index = labels.at<int>(n);dst.at<Vec3b>(i, j) = colors[index];n++;}imshow("src", src);imshow("dst", dst);waitKey(0);return 0;}

点坐标分类

从文件points.txt中读取点坐标以及对应的分类,然后使用EM算法对点所在区域进行划分。

#include <opencv2/opencv.hpp>#include "opencv2/ml.hpp"#include <iostream> #include <fstream> using namespace std;using namespace cv;using namespace cv::ml;//EM算法int main(){Mat src, dst;vector<Point> trainedPoints;vector<int> trainedPointsMarkers;//读取文件中的点坐标FILE *fp;int flge = 0;int fpoint,flabel;fp = fopen("E:\\points.txt", "r+");if (fp == NULL){printf("Cannot open the file!\n");exit(0);}Point point;while (!feof(fp)){ fscanf(fp, "%d", &fpoint);if (feof(fp)) break;//依次为横坐标、纵坐标、分类if ((flge%3==0? point.x = fpoint: flge%3==1? point.y = fpoint:flge%3==2? flabel = fpoint : -1)<0) return -1;if (flge%3==2){trainedPoints.push_back(point);trainedPointsMarkers.push_back(flabel);}flge++;}vector<Vec3b> colors(4);colors[0] = Vec3b(0, 255, 0);colors[1] = Vec3b(0, 0, 255);colors[2] = Vec3b(0, 255, 255);colors[3] = Vec3b(255, 0, 0);src.create( 480, 640, CV_8UC3 );src = Scalar::all(0);// 绘制点for( size_t i = 0; i < trainedPoints.size(); i++ ){Scalar c = colors[trainedPointsMarkers[i]];circle( src, trainedPoints[i], 3, c, -1 );}src.copyTo(dst);imshow( "points", src );Mat samples;Mat(trainedPoints).reshape(1, (int)trainedPoints.size()).convertTo(samples, CV_32F);int nmodels = (int)colors.size();vector<Ptr<EM> > em_models(nmodels);Mat modelSamples;for( int i = 0; i < nmodels; i++ ){modelSamples.release();for( int j = 0; j < samples.rows; j++ ){if( trainedPointsMarkers[j] == i )modelSamples.push_back(samples.row(j));}// 训练模型if( !modelSamples.empty() ){const int componentCount = 5;Ptr<EM> em = EM::create();//高斯混合模型中混合成分的数量em->setClustersNumber(componentCount);//协方差矩阵的类型。em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);//训练模型em->trainEM(modelSamples, noArray(), noArray(), noArray());em_models[i] = em;}}Mat testSample(1, 2, CV_32FC1 );Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX));for( int y = 0; y < src.rows; y += 3 ){for( int x = 0; x < src.cols; x += 3 ){testSample.at<float>(0) = (float)x;testSample.at<float>(1) = (float)y;for( int i = 0; i < nmodels; i++ ){if( !em_models[i].empty() )logLikelihoods.at<double>(i) = em_models[i]->predict2(testSample, noArray())[0];}Point maxLoc;minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);dst.at<Vec3b>(y, x) = colors[maxLoc.x];}}imshow( "EM", dst );waitKey();return 0;}

文件points.txt中的内容为:

(依次为横坐标、纵坐标、分类)

281 234 0265 227 0261 204 0273 185 0298 171 0326 178 0328 206 0330 226 0323 245 0300 256 0280 259 0262 257 0245 240 0238 229 0236 206 0239 184 0265 158 0279 154 0297 153 0325 153 0338 168 0352 206 0353 229 0305 211 0305 237 0240 308 3219 300 3200 281 3189 237 3184 196 3191 163 3214 140 3246 128 3288 127 3337 128 3265 122 3306 123 3287 113 3322 114 3349 118 3368 135 3389 174 3399 197 3399 233 3388 261 3365 295 3344 310 3280 319 3265 319 3309 316 3262 333 3223 320 3193 303 3180 290 3175 262 3165 228 3169 183 3177 143 3186 127 3205 115 3260 104 3235 97 3293 97 3357 102 3392 119 3408 150 3396 149 3412 200 3419 247 3405 291 3366 320 3303 332 3341 333 3330 105 3273 92 3280 209 0333 256 0298 271 0165 163 3157 205 3159 246 3159 276 3171 309 3189 327 3207 333 3239 340 3284 352 3339 351 3371 334 3382 311 3394 282 3423 269 3425 218 3424 179 3411 157 3340 85 3308 85 3221 118 3151 175 3254 88 3281 79 3185 374 1171 359 1151 342 1137 323 1122 298 1116 272 1116 233 1116 196 1119 157 1119 139 1137 103 1145 92 1157 73 1181 62 1218 55 1260 51 1317 49 1362 53 1393 65 1427 83 1454 112 1468 130 1490 168 1504 197 1511 219 1516 235 1539 276 1563 335 1580 373 1593 404 1605 425 1616 444 1628 466 1222 382 1259 384 1306 389 1293 401 1274 402 1247 402 1210 402 1193 393 1159 380 1141 369 1123 339 1116 316 1105 294 195 270 192 244 189 217 188 188 189 160 190 111 193 94 181 132 1110 117 1116 73 1149 45 1165 36 1222 35 1290 27 1251 31 1351 40 1329 35 1284 40 1384 48 1410 55 1441 77 1461 102 1484 130 1510 171 1524 201 1535 232 1546 260 1562 297 1579 322 1595 351 1615 379 1632 406 1620 400 1473 230 2479 268 2477 299 2473 316 2467 332 2462 252 2462 303 2457 330 2450 346 2428 375 2395 397 2381 406 2320 436 2315 437 2283 447 2224 450 2190 445 2177 442 2140 428 2125 416 299 405 284 392 279 387 259 347 256 336 245 299 240 259 237 224 225 155 219 97 217 43 217 26 220 75 218 132 218 170 211 202 213 233 218 265 225 307 232 342 244 375 251 389 259 410 269 430 2117 459 2179 460 2110 442 2186 472 2286 476 2353 467 2377 457 2236 462 2318 459 2353 447 2386 427 2404 419 2448 381 2413 431 2428 399 2470 342 2447 401 2375 438 2369 421 2337 445 2265 464 2220 471 2182 470 2138 453 285 417 2119 429 2149 447 2157 465 2561 274 1553 317 1586 359 1586 300 1624 382 1623 318 1608 297 1577 240 1556 219 1550 205 1522 152 1619 325 1628 347 1627 279 1578 215 1546 189 1601 264 1587 231 1496 137 1492 104 1479 85 1464 66 1441 49 1423 42 1364 24 1345 21 1314 18 1262 17 1243 17 1202 22 1175 46 1152 61 1125 91 1101 151 199 172 1102 233 1104 246 176 159 183 196 180 235 187 258 190 277 197 307 1110 323 1136 360 1160 366 1185 383 1179 401 1221 401 1265 396 1230 390 1267 412 1309 404 1323 399 1341 390 1328 386 1463 366 2433 409 2428 431 2401 454 2358 468 2329 472 2294 464 2250 468 2224 476 2205 463 2249 446 2210 439 2161 427 2115 404 298 393 284 377 273 367 252 334 241 324 220 303 216 284 210 320 214 353 28 353 29 384 255 418 236 391 229 395 232 208 215 165 28 127 26 98 25 72 25 63 25 4 26 29 231 242 233 275 27 300 220 338 228 367 225 386 229 414 243 430 268 450 2102 469 297 457 292 439 2100 430 2135 473 2292 412 1244 414 1214 411 1253 411 1275 390 1207 377 1334 399 1296 416 1243 416 1199 406 1356 379 1374 374 1361 383 1353 391 1342 395 1333 401 1313 414 1245 382 1242 393 1230 408 1176 384 1154 392 1134 381 1121 359 1113 344 1291 384 1285 397 1285 407 1279 412 1107 211 1101 195 1444 362 2424 386 2350 423 2338 429 2304 444 2263 453 2405 394 2427 365 2443 334 2449 315 2458 291 2463 275 2478 284 2463 390 2447 420 2438 435 2416 458 2413 462 2392 474 2378 475 2

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。