001/*******************************************************************************
002 * This software is provided as a supplement to the authors' textbooks on digital
003 * image processing published by Springer-Verlag in various languages and editions.
004 * Permission to use and distribute this software is granted under the BSD 2-Clause
005 * "Simplified" License (see http://opensource.org/licenses/BSD-2-Clause).
006 * Copyright (c) 2006-2023 Wilhelm Burger, Mark J. Burge. All rights reserved.
007 * Visit https://imagingbook.com for additional details.
008 ******************************************************************************/
009package imagingbook.common.geometry.fitting.line;
010
011import imagingbook.common.geometry.basic.Pnt2d;
012import imagingbook.common.geometry.line.AlgebraicLine;
013import imagingbook.common.geometry.line.SlopeInterceptLine;
014
015import static imagingbook.common.math.Arithmetic.sqr;
016
017/**
018 * <p>
019 * This class implements line fitting by linear regression to a set of 2D points. See Sec. 10.2.1 of [1] for additional
020 * details.
021 * </p>
022 * <p>
023 * [1] W. Burger, M.J. Burge, <em>Digital Image Processing &ndash; An Algorithmic Introduction</em>, 3rd ed, Springer
024 * (2022).
025 * </p>
026 *
027 * @author WB
028 * @version 2022/09/29
029 */
030public class LinearRegressionFit implements LineFit {
031        
032        private final int n;
033        private final double[] p;       // line parameters A,B,C
034        private double k, d;
035
036        /**
037         * Constructor, performs a linear regression fit to the specified points. At least two different points are
038         * required.
039         *
040         * @param points an array with at least 2 points
041         */
042        public LinearRegressionFit(Pnt2d[] points) {
043                if (points.length < 2) {
044                        throw new IllegalArgumentException("line fit requires at least 2 points");
045                }
046                this.n = points.length;
047                this.p = fit(points);
048        }
049        
050        @Override
051        public int getSize() {
052                return n;
053        }
054
055        @Override
056        public double[] getLineParameters() {
057                return p;
058        }
059
060        /**
061         * Returns the slope parameter k for the fitted line y = k * x + d.
062         *
063         * @return line parameter k
064         */
065        public double getK() {
066                return k;
067        }
068
069        /**
070         * Returns the intercept parameter d for the fitted line y = k * x + d.
071         *
072         * @return line parameter d
073         */
074        public double getD() {
075                return d;
076        }
077        
078        // ----------------------------------------------------------------------
079        
080        private double[] fit(Pnt2d[] points) {
081                final int n = points.length;
082        
083                double Sx = 0, Sy = 0, Sxx = 0, Sxy = 0;
084                
085                for (int i = 0; i < n; i++) {
086                        final double x = points[i].getX();
087                        final double y = points[i].getY();
088                        Sx += x;
089                        Sy += y;
090                        Sxx += sqr(x);
091                        Sxy += x * y;
092                }
093                        
094                double den = sqr(Sx) - n * Sxx;
095                this.k = (Sx * Sy - n * Sxy) / den;
096                this.d = (Sx * Sxy - Sxx * Sy) / den;
097                
098                AlgebraicLine line = AlgebraicLine.from(new SlopeInterceptLine(k, d));
099                return line.getParameters();
100        }
101
102
103        /**
104         * Calculates and returns the sum of the squared differences between
105         * the y-coordinates of the data points (xi, yi) and the associated y-value
106         * of the regression line (y = k x + d).
107         * 
108         * @param points an array of 2D points
109         * @return the squared regression error
110         */
111        public double getSquaredRegressionError(Pnt2d[] points) {
112                double s2 = 0;
113                for (Pnt2d p : points) {
114                        double y = k * p.getX() + d;
115                        s2 = s2 + sqr(y - p.getY());
116                }
117                return s2;
118        }
119        
120}