001/******************************************************************************* 002 * This software is provided as a supplement to the authors' textbooks on digital 003 * image processing published by Springer-Verlag in various languages and editions. 004 * Permission to use and distribute this software is granted under the BSD 2-Clause 005 * "Simplified" License (see http://opensource.org/licenses/BSD-2-Clause). 006 * Copyright (c) 2006-2023 Wilhelm Burger, Mark J. Burge. All rights reserved. 007 * Visit https://imagingbook.com for additional details. 008 ******************************************************************************/ 009package imagingbook.common.geometry.fitting.points; 010 011import imagingbook.common.geometry.basic.Pnt2d; 012import imagingbook.common.geometry.basic.Pnt2d.PntDouble; 013import imagingbook.common.geometry.basic.Pnt2d.PntInt; 014import imagingbook.common.math.Matrix; 015import imagingbook.common.math.PrintPrecision; 016import org.apache.commons.math3.linear.ArrayRealVector; 017import org.apache.commons.math3.linear.LUDecomposition; 018import org.apache.commons.math3.linear.MatrixUtils; 019import org.apache.commons.math3.linear.RealMatrix; 020import org.apache.commons.math3.linear.RealVector; 021import org.apache.commons.math3.linear.SingularValueDecomposition; 022 023import static imagingbook.common.math.Arithmetic.sqr; 024 025/** 026 * <p> 027 * Implements a 2-dimensional Procrustes fit, using the algorithm described in [1]. Usage example: 028 * </p> 029 * <pre> 030 * Point[] P = ... // create sequence of 2D source points 031 * Point[] Q = ... // create sequence of 2D target points 032 * ProcrustesFit pf = new ProcrustesFit(P, Q); 033 * RealMatrix R = pf.getRotation(); 034 * RealVector t = pf.getTranslation(); 035 * double s = pf.getScale(); 036 * double err = pf.getError(); 037 * RealMatrix A = pf.getTransformationMatrix(); 038 * </pre> 039 * <p> 040 * [1] Shinji Umeyama, "Least-squares estimation of transformation parameters between two point patterns", IEEE 041 * Transactions on Pattern Analysis and Machine Intelligence, 13.4 (Apr. 1991), pp. 376–380. 042 * </p> 043 * 044 * @author WB 045 * @version 2021/11/27 046 */ 047public class ProcrustesFit2d implements LinearFit2d { 048 049 private final RealMatrix R; // orthogonal (rotation) matrix 050 private final RealVector t; // translation vector 051 private final double s; // uniform scale 052 053 private final RealMatrix A; // resulting transformation matrix 054 private final double err; // RMS fitting error 055 056 // -------------------------------------------------------------- 057 058 /** 059 * Convenience constructor, with parameters {@code allowTranslation}, {@code allowScaling} and {@code forceRotation} 060 * set to {@code true}. 061 * 062 * @param P the source points 063 * @param Q the target points 064 */ 065 public ProcrustesFit2d(Pnt2d[] P, Pnt2d[] Q) { 066 this(P, Q, true, true, true); 067 } 068 069 /** 070 * Full constructor. 071 * 072 * @param P the first point sequence 073 * @param Q the second point sequence 074 * @param allowTranslation if {@code true}, translation (t) between point sets is considered, otherwise zero 075 * translation is assumed 076 * @param allowScaling if {@code true}, scaling (s) between point sets is considered, otherwise unit scale assumed 077 * @param forceRotation if {@code true}, the orthogonal part of the transformation (Q) is forced to a true rotation 078 * and no reflection is allowed 079 */ 080 public ProcrustesFit2d(Pnt2d[] P, Pnt2d[] Q, boolean allowTranslation, boolean allowScaling, boolean forceRotation) { 081 checkSize(P, Q); 082 083 double[] meanP = null; 084 double[] meanY = null; 085 086 if (allowTranslation) { 087 meanP = getMeanVec(P); 088 meanY = getMeanVec(Q); 089 } 090 091 RealMatrix vP = makeDataMatrix(P, meanP); 092 RealMatrix vQ = makeDataMatrix(Q, meanY); 093 MatrixUtils.checkAdditionCompatible(vP, vQ); // P, Q of same dimensions? 094 095 RealMatrix QPt = vQ.multiply(vP.transpose()); 096 SingularValueDecomposition svd = new SingularValueDecomposition(QPt); 097 098 RealMatrix U = svd.getU(); 099 RealMatrix S = svd.getS(); 100 RealMatrix V = svd.getV(); 101 102 double d = (svd.getRank() >= 2) ? det(QPt) : det(U) * det(V); 103 104 RealMatrix D = MatrixUtils.createRealIdentityMatrix(2); 105 if (d < 0 && forceRotation) 106 D.setEntry(1, 1, -1); 107 108 R = U.multiply(D).multiply(V.transpose()); 109 110 double normP = vP.getFrobeniusNorm(); 111 double normQ = vQ.getFrobeniusNorm(); 112 113 s = (allowScaling) ? 114 S.multiply(D).getTrace() / sqr(normP) : 1.0; 115 116 if (allowTranslation) { 117 RealVector ma = MatrixUtils.createRealVector(meanP); 118 RealVector mb = MatrixUtils.createRealVector(meanY); 119 t = mb.subtract(R.scalarMultiply(s).operate(ma)); 120 } 121 else { 122 t = new ArrayRealVector(2); // zero vector 123 } 124 125 // make the transformation matrix A 126 RealMatrix cR = R.scalarMultiply(s); 127 A = MatrixUtils.createRealMatrix(2, 3); 128 A.setSubMatrix(cR.getData(), 0, 0); 129 A.setColumnVector(2, t); 130 131 // calculate the fitting error: 132 err = Math.sqrt(sqr(normQ) - sqr(S.multiply(D).getTrace() / normP)); 133 } 134 135 // ----------------------------------------------------------------- 136 137 /** 138 * Retrieves the estimated scale. 139 * @return The estimated scale (or 1 if {@code allowscaling = false}). 140 */ 141 public double getScale() { 142 return s; 143 } 144 145 /** 146 * Retrieves the estimated orthogonal (rotation) matrix. 147 * @return The estimated rotation matrix. 148 */ 149 public double[][] getRotation() { 150 return R.getData(); 151 } 152 153 /** 154 * Retrieves the estimated translation vector. 155 * @return The estimated translation vector. 156 */ 157 public double[] getTranslation() { 158 return t.toArray(); 159 } 160 161 // -------------------------------------------------------- 162 163 @Override 164 public double[][] getTransformationMatrix() { 165 return A.getData(); 166// return A; 167 } 168 169 @Override 170 public double getError() { 171 return err; 172 } 173 174 /** 175 * Calculates the total error for the estimated fit as the sum of the squared Euclidean distances between the 176 * transformed point set X and the reference set Y. This method is provided for testing as an alternative to the 177 * quicker {@link #getError()} method. 178 * 179 * @param P Sequence of n-dimensional points. 180 * @param Q Sequence of n-dimensional points (reference). 181 * @return The total error for the estimated fit. 182 */ 183 private double getEuclideanError(Pnt2d[] P, Pnt2d[] Q) { 184 int m = Math.min(P.length, Q.length); 185 RealMatrix sR = R.scalarMultiply(s); 186 double errSum = 0; 187 for (int i = 0; i < m; i++) { 188 RealVector p = new ArrayRealVector(P[i].toDoubleArray()); 189 RealVector q = new ArrayRealVector(Q[i].toDoubleArray()); 190 RealVector pp = sR.operate(p).add(t); 191 //System.out.format("p=%s, q=%s, pp=%s\n", p.toString(), q.toString(), pp.toString()); 192 double e = pp.subtract(q).getNorm(); 193 errSum = errSum + sqr(e); 194 } 195 return Math.sqrt(errSum); // correct! 196 } 197 198 // ----------------------------------------------------------------- 199 200 private double det(RealMatrix M) { 201 return new LUDecomposition(M).getDeterminant(); 202 } 203 204 private double[] getMeanVec(Pnt2d[] points) { 205 double sumX = 0; 206 double sumY = 0; 207 for (Pnt2d p : points) { 208 sumX = sumX + p.getX(); 209 sumY = sumY + p.getY(); 210 } 211 return new double[] {sumX / points.length, sumY / points.length}; 212 } 213 214 private RealMatrix makeDataMatrix(Pnt2d[] points, double[] meanX) { 215 RealMatrix M = MatrixUtils.createRealMatrix(2, points.length); 216 RealVector mean = MatrixUtils.createRealVector(meanX); 217 int i = 0; 218 for (Pnt2d p : points) { 219 RealVector cv = p.toRealVector(); 220// RealVector cv = MatrixUtils.createRealVector(p.toDoubleArray()); 221 if (meanX != null) { 222 cv = cv.subtract(mean); 223 } 224 M.setColumnVector(i, cv); 225 i++; 226 } 227 return M; 228 } 229 230// private void printSVD(SingularValueDecomposition svd) { 231// RealMatrix U = svd.getU(); 232// RealMatrix S = svd.getS(); 233// RealMatrix V = svd.getV(); 234// System.out.println("------ SVD ---------------"); 235// System.out.println("U = " + Matrix.toString(U.getData())); 236// System.out.println("S = " + Matrix.toString(S.getData())); 237// System.out.println("V = " + Matrix.toString(V.getData())); 238// System.out.println("--------------------------"); 239// } 240 241 242 private void checkSize(Pnt2d[] P, Pnt2d[] Q) { 243 if (P.length < 3 || Q.length < 3) { 244 throw new IllegalArgumentException("At least 3 point pairs are required to calculate this fit"); 245 } 246 } 247 248 249 250 private static double roundToDigits(double x, int ndigits) { 251 int d = (int) Math.pow(10, ndigits); 252 return Math.rint(x * d) / d; 253 } 254 255 // -------------------------------------------------------------------------------- 256 257 // public static void main(String[] args) { 258 // PrintPrecision.set(6); 259 // int NDIGITS = 1; 260 // 261 // boolean allowTranslation = true; 262 // boolean allowScaling = true; 263 // boolean forceRotation = true; 264 // 265 // double a = 0.6; 266 // double[][] R0data = 267 // {{ Math.cos(a), -Math.sin(a) }, 268 // { Math.sin(a), Math.cos(a) }}; 269 // 270 // RealMatrix R0 = MatrixUtils.createRealMatrix(R0data); 271 // double[] t0 = {4, -3}; 272 // double s = 3.5; 273 // 274 // System.out.format("original alpha: a = %.6f\n", a); 275 // System.out.println("original rotation: R = \n" + Matrix.toString(R0.getData())); 276 // System.out.println("original translation: t = " + Matrix.toString(t0)); 277 // System.out.format("original scale: s = %.6f\n", s); 278 // System.out.println(); 279 // 280 // Pnt2d[] P = { 281 // PntInt.from(2, 5), 282 // PntInt.from(7, 3), 283 // PntInt.from(0, 9), 284 // PntInt.from(5, 4) 285 // }; 286 // 287 // Pnt2d[] Q = new Pnt2d[P.length]; 288 // 289 // for (int i = 0; i < P.length; i++) { 290 // Pnt2d q = PntDouble.from(R0.operate(P[i].toDoubleArray())); 291 // // noise! 292 // double qx = roundToDigits(s * q.getX() + t0[0], NDIGITS); 293 // double qy = roundToDigits(s * q.getY() + t0[1], NDIGITS); 294 // Q[i] = Pnt2d.PntDouble.from(qx, qy); 295 // } 296 // 297 // //P[0] = Point.create(2, 0); // to provoke a large error 298 // 299 // ProcrustesFit2d pf = new ProcrustesFit2d(P, Q, allowTranslation, allowScaling, forceRotation); 300 // 301 // double[][] R = pf.getRotation(); 302 // System.out.format("estimated alpha: a = %.6f\n", Math.acos(R[0][0])); 303 // System.out.println("estimated rotation: R = \n" + Matrix.toString(R)); 304 // double[] T = pf.getTranslation(); 305 // System.out.println("estimated translation: t = " + Matrix.toString(T)); 306 // System.out.format("estimated scale: s = %.6f\n", pf.getScale()); 307 // 308 // System.out.println(); 309 // System.out.format("RMS fitting error = %.6f\n", pf.getError()); 310 // System.out.format("euclidean error (test) = %.6f\n", pf.getEuclideanError(P, Q)); 311 // 312 // double[][] A = pf.getTransformationMatrix(); 313 // System.out.println("transformation matrix: A = \n" + Matrix.toString(A)); 314 // } 315 316 /* 317 original alpha: a = 0.600000 318 original rotation: R = 319 {{0.825336, -0.564642}, 320 {0.564642, 0.825336}} 321 original translation: t = {4.000000, -3.000000} 322 original scale: s = 3.500000 323 324 estimated alpha: a = 0.599589 325 estimated rotation: R = 326 {{0.825568, -0.564303}, 327 {0.564303, 0.825568}} 328 estimated translation: t = {3.980905, -3.011055} 329 estimated scale: s = 3.500560 330 331 fitting error = 0.048079 332 euclidean error (test) = 0.048079 333 transformation matrix: A = 334 {{2.889950, -1.975377, 3.980905}, 335 {1.975377, 2.889950, -3.011055}} 336 */ 337} 338