001/*******************************************************************************
002 * This software is provided as a supplement to the authors' textbooks on digital
003 * image processing published by Springer-Verlag in various languages and editions.
004 * Permission to use and distribute this software is granted under the BSD 2-Clause
005 * "Simplified" License (see http://opensource.org/licenses/BSD-2-Clause).
006 * Copyright (c) 2006-2023 Wilhelm Burger, Mark J. Burge. All rights reserved.
007 * Visit https://imagingbook.com for additional details.
008 ******************************************************************************/
009package imagingbook.common.geometry.fitting.points;
010
011import imagingbook.common.geometry.basic.Pnt2d;
012import imagingbook.common.geometry.basic.Pnt2d.PntDouble;
013import imagingbook.common.geometry.basic.Pnt2d.PntInt;
014import imagingbook.common.math.Matrix;
015import imagingbook.common.math.PrintPrecision;
016import org.apache.commons.math3.linear.ArrayRealVector;
017import org.apache.commons.math3.linear.LUDecomposition;
018import org.apache.commons.math3.linear.MatrixUtils;
019import org.apache.commons.math3.linear.RealMatrix;
020import org.apache.commons.math3.linear.RealVector;
021import org.apache.commons.math3.linear.SingularValueDecomposition;
022
023import static imagingbook.common.math.Arithmetic.sqr;
024
025/**
026 * <p>
027 * Implements a 2-dimensional Procrustes fit, using the algorithm described in [1]. Usage example:
028 * </p>
029 * <pre>
030 * Point[] P = ... // create sequence of 2D source points
031 * Point[] Q = ... // create sequence of 2D target points
032 * ProcrustesFit pf = new ProcrustesFit(P, Q);
033 * RealMatrix R = pf.getRotation();
034 * RealVector t = pf.getTranslation();
035 * double s = pf.getScale();
036 * double err = pf.getError();
037 * RealMatrix A = pf.getTransformationMatrix();
038 * </pre>
039 * <p>
040 * [1] Shinji Umeyama, "Least-squares estimation of transformation parameters between two point patterns", IEEE
041 * Transactions on Pattern Analysis and Machine Intelligence, 13.4 (Apr. 1991), pp. 376–380.
042 * </p>
043 *
044 * @author WB
045 * @version 2021/11/27
046 */
047public class ProcrustesFit2d implements LinearFit2d {
048        
049        private final RealMatrix R;                                     // orthogonal (rotation) matrix
050        private final RealVector t;                                     // translation vector
051        private final double s;                                         // uniform scale
052        
053        private final RealMatrix A;                                     // resulting transformation matrix
054        private final double err;                                       // RMS fitting error
055        
056        // --------------------------------------------------------------
057
058        /**
059         * Convenience constructor, with parameters {@code allowTranslation}, {@code allowScaling} and {@code forceRotation}
060         * set to {@code true}.
061         *
062         * @param P the source points
063         * @param Q the target points
064         */
065        public ProcrustesFit2d(Pnt2d[] P, Pnt2d[] Q) {
066                this(P, Q, true, true, true);
067        }
068
069        /**
070         * Full constructor.
071         *
072         * @param P the first point sequence
073         * @param Q the second point sequence
074         * @param allowTranslation if {@code true}, translation (t) between point sets is considered, otherwise zero
075         * translation is assumed
076         * @param allowScaling if {@code true}, scaling (s) between point sets is considered, otherwise unit scale assumed
077         * @param forceRotation if {@code true}, the orthogonal part of the transformation (Q) is forced to a true rotation
078         * and no reflection is allowed
079         */
080        public ProcrustesFit2d(Pnt2d[] P, Pnt2d[] Q, boolean allowTranslation, boolean allowScaling, boolean forceRotation) {
081                checkSize(P, Q);
082                
083                double[] meanP = null;
084                double[] meanY = null;
085                
086                if (allowTranslation) {
087                        meanP = getMeanVec(P);
088                        meanY = getMeanVec(Q);
089                }
090                
091                RealMatrix vP = makeDataMatrix(P, meanP);
092                RealMatrix vQ = makeDataMatrix(Q, meanY);
093                MatrixUtils.checkAdditionCompatible(vP, vQ);    // P, Q of same dimensions?
094                
095                RealMatrix QPt = vQ.multiply(vP.transpose());
096                SingularValueDecomposition svd = new SingularValueDecomposition(QPt);
097                
098                RealMatrix U = svd.getU();
099                RealMatrix S = svd.getS();
100                RealMatrix V = svd.getV();
101                        
102                double d = (svd.getRank() >= 2) ? det(QPt) : det(U) * det(V);
103                
104                RealMatrix D = MatrixUtils.createRealIdentityMatrix(2);
105                if (d < 0 && forceRotation)
106                        D.setEntry(1, 1, -1);
107                
108                R = U.multiply(D).multiply(V.transpose());
109                
110                double normP = vP.getFrobeniusNorm();
111                double normQ = vQ.getFrobeniusNorm();
112                
113                s = (allowScaling) ? 
114                                S.multiply(D).getTrace() / sqr(normP) : 1.0;
115                
116                if (allowTranslation) {
117                        RealVector ma = MatrixUtils.createRealVector(meanP);
118                        RealVector mb = MatrixUtils.createRealVector(meanY);
119                        t = mb.subtract(R.scalarMultiply(s).operate(ma));
120                }
121                else {
122                        t = new ArrayRealVector(2);     // zero vector
123                }
124                
125                // make the transformation matrix A
126                RealMatrix cR = R.scalarMultiply(s);
127                A = MatrixUtils.createRealMatrix(2, 3);
128                A.setSubMatrix(cR.getData(), 0, 0);
129                A.setColumnVector(2, t);
130                
131                // calculate the fitting error:
132                err = Math.sqrt(sqr(normQ) - sqr(S.multiply(D).getTrace() / normP));    
133        }
134        
135        // -----------------------------------------------------------------
136        
137        /**
138         * Retrieves the estimated scale.
139         * @return The estimated scale (or 1 if {@code allowscaling = false}).
140         */
141        public double getScale() {
142                return s;
143        }
144        
145        /**
146         * Retrieves the estimated orthogonal (rotation) matrix.
147         * @return The estimated rotation matrix.
148         */
149        public double[][] getRotation() {
150                return R.getData();
151        }
152        
153        /**
154         * Retrieves the estimated translation vector.
155         * @return The estimated translation vector.
156         */
157        public double[] getTranslation() {
158                return t.toArray();
159        }
160        
161        // --------------------------------------------------------
162        
163        @Override
164        public double[][] getTransformationMatrix() {
165                return A.getData();
166//              return A;
167        }
168        
169        @Override
170        public double getError() {
171                return err;
172        }
173
174        /**
175         * Calculates the total error for the estimated fit as the sum of the squared Euclidean distances between the
176         * transformed point set X and the reference set Y. This method is provided for testing as an alternative to the
177         * quicker {@link #getError()} method.
178         *
179         * @param P Sequence of n-dimensional points.
180         * @param Q Sequence of n-dimensional points (reference).
181         * @return The total error for the estimated fit.
182         */
183        private double getEuclideanError(Pnt2d[] P, Pnt2d[] Q) {
184                int m = Math.min(P.length,  Q.length);
185                RealMatrix sR = R.scalarMultiply(s);
186                double errSum = 0;
187                for (int i = 0; i < m; i++) {
188                        RealVector p = new ArrayRealVector(P[i].toDoubleArray());
189                        RealVector q = new ArrayRealVector(Q[i].toDoubleArray());
190                        RealVector pp = sR.operate(p).add(t);
191                        //System.out.format("p=%s, q=%s, pp=%s\n", p.toString(), q.toString(), pp.toString());
192                        double e = pp.subtract(q).getNorm();
193                        errSum = errSum + sqr(e);
194                }
195                return Math.sqrt(errSum);       // correct!
196        }
197        
198        // -----------------------------------------------------------------
199        
200        private double det(RealMatrix M) {
201                return new LUDecomposition(M).getDeterminant();
202        }
203        
204        private double[] getMeanVec(Pnt2d[] points) {
205                double sumX = 0;
206                double sumY = 0;
207                for (Pnt2d p : points) {
208                        sumX = sumX + p.getX();
209                        sumY = sumY + p.getY();
210                }
211                return new double[] {sumX / points.length, sumY / points.length};
212        }
213        
214        private RealMatrix makeDataMatrix(Pnt2d[] points, double[] meanX) {
215                RealMatrix M = MatrixUtils.createRealMatrix(2, points.length);
216                RealVector mean = MatrixUtils.createRealVector(meanX);
217                int i = 0;
218                for (Pnt2d p : points) {
219                        RealVector cv = p.toRealVector();
220//                      RealVector cv = MatrixUtils.createRealVector(p.toDoubleArray());
221                        if (meanX != null) {
222                                cv = cv.subtract(mean);
223                        }
224                        M.setColumnVector(i, cv);
225                        i++;
226                }
227                return M;
228        }
229        
230//      private void printSVD(SingularValueDecomposition svd) {
231//              RealMatrix U = svd.getU();
232//              RealMatrix S = svd.getS();
233//              RealMatrix V = svd.getV();
234//              System.out.println("------ SVD ---------------");
235//              System.out.println("U = " + Matrix.toString(U.getData()));
236//              System.out.println("S = " + Matrix.toString(S.getData()));
237//              System.out.println("V = " + Matrix.toString(V.getData()));
238//              System.out.println("--------------------------");
239//      }
240        
241        
242        private void checkSize(Pnt2d[] P, Pnt2d[] Q) {
243                if (P.length < 3 || Q.length < 3) {
244                        throw new IllegalArgumentException("At least 3 point pairs are required to calculate this fit");
245                }
246        }
247
248        
249        
250        private static double roundToDigits(double x, int ndigits) {
251                int d = (int) Math.pow(10, ndigits);
252                return Math.rint(x * d) / d;
253        }
254
255        // --------------------------------------------------------------------------------
256        
257        // public static void main(String[] args) {
258        //      PrintPrecision.set(6);
259        //      int NDIGITS = 1;
260        //
261        //      boolean allowTranslation = true;
262        //      boolean allowScaling = true;
263        //      boolean forceRotation = true;
264        //
265        //      double a = 0.6;
266        //      double[][] R0data =
267        //              {{ Math.cos(a), -Math.sin(a) },
268        //               { Math.sin(a),  Math.cos(a) }};
269        //
270        //      RealMatrix R0 = MatrixUtils.createRealMatrix(R0data);
271        //      double[] t0 = {4, -3};
272        //      double s = 3.5;
273        //
274        //      System.out.format("original alpha: a = %.6f\n", a);
275        //      System.out.println("original rotation: R = \n" + Matrix.toString(R0.getData()));
276        //      System.out.println("original translation: t = " + Matrix.toString(t0));
277        //      System.out.format("original scale: s = %.6f\n", s);
278        //      System.out.println();
279        //
280        //      Pnt2d[] P = {
281        //                      PntInt.from(2, 5),
282        //                      PntInt.from(7, 3),
283        //                      PntInt.from(0, 9),
284        //                      PntInt.from(5, 4)
285        //      };
286        //
287        //      Pnt2d[] Q = new Pnt2d[P.length];
288        //
289        //      for (int i = 0; i < P.length; i++) {
290        //              Pnt2d q = PntDouble.from(R0.operate(P[i].toDoubleArray()));
291        //              // noise!
292        //              double qx = roundToDigits(s * q.getX() + t0[0], NDIGITS);
293        //              double qy = roundToDigits(s * q.getY() + t0[1], NDIGITS);
294        //              Q[i] = Pnt2d.PntDouble.from(qx, qy);
295        //      }
296        //
297        //      //P[0] = Point.create(2, 0);    // to provoke a large error
298        //
299        //      ProcrustesFit2d pf = new ProcrustesFit2d(P, Q, allowTranslation, allowScaling, forceRotation);
300        //
301        //      double[][] R = pf.getRotation();
302        //      System.out.format("estimated alpha: a = %.6f\n", Math.acos(R[0][0]));
303        //      System.out.println("estimated rotation: R = \n" + Matrix.toString(R));
304        //      double[] T = pf.getTranslation();
305        //      System.out.println("estimated translation: t = " + Matrix.toString(T));
306        //      System.out.format("estimated scale: s = %.6f\n", pf.getScale());
307        //
308        //      System.out.println();
309        //      System.out.format("RMS fitting error = %.6f\n", pf.getError());
310        //      System.out.format("euclidean error (test) = %.6f\n", pf.getEuclideanError(P, Q));
311        //
312        //      double[][] A = pf.getTransformationMatrix();
313        //      System.out.println("transformation matrix: A = \n" + Matrix.toString(A));
314        // }
315
316        /*
317        original alpha: a = 0.600000
318        original rotation: R = 
319        {{0.825336, -0.564642}, 
320        {0.564642, 0.825336}}
321        original translation: t = {4.000000, -3.000000}
322        original scale: s = 3.500000
323        
324        estimated alpha: a = 0.599589
325        estimated rotation: R = 
326        {{0.825568, -0.564303}, 
327        {0.564303, 0.825568}}
328        estimated translation: t = {3.980905, -3.011055}
329        estimated scale: s = 3.500560
330        
331        fitting error = 0.048079
332        euclidean error (test) = 0.048079
333        transformation matrix: A = 
334        {{2.889950, -1.975377, 3.980905}, 
335        {1.975377, 2.889950, -3.011055}}
336         */
337}
338