001/*******************************************************************************
002 * This software is provided as a supplement to the authors' textbooks on digital
003 * image processing published by Springer-Verlag in various languages and editions.
004 * Permission to use and distribute this software is granted under the BSD 2-Clause
005 * "Simplified" License (see http://opensource.org/licenses/BSD-2-Clause).
006 * Copyright (c) 2006-2023 Wilhelm Burger, Mark J. Burge. All rights reserved.
007 * Visit https://imagingbook.com for additional details.
008 ******************************************************************************/
009package imagingbook.common.ransac;
010
011import ij.process.ByteProcessor;
012import imagingbook.common.geometry.basic.Pnt2d;
013import imagingbook.common.geometry.basic.Primitive2d;
014import imagingbook.common.ij.DialogUtils.DialogLabel;
015import imagingbook.common.ij.IjUtils;
016import imagingbook.common.util.ParameterBundle;
017
018import java.util.ArrayList;
019import java.util.List;
020import java.util.Random;
021
022/**
023 * <p>
024 * Generic RANSAC detector for 2D primitives. See Sec. 12.1 of [1] for additional details. This abstract class defines
025 * the core RANSAC functionality used by all derived (concrete) classes.
026 * </p>
027 * <p>
028 * [1] W. Burger, M.J. Burge, <em>Digital Image Processing &ndash; An Algorithmic Introduction</em>, 3rd ed, Springer
029 * (2022).
030 * </p>
031 *
032 * @param <T> generic type extending {@link Primitive2d}
033 * @author WB
034 * @version 2022/11/19
035 * @see RansacLineDetector
036 * @see RansacCircleDetector
037 * @see RansacEllipseDetector
038 * @see Primitive2d
039 */
040public abstract class RansacDetector<T extends Primitive2d> {
041        
042        /**
043         * Parameters used by all RANSAC types.
044         */
045        public static class RansacParameters implements ParameterBundle<RansacDetector<?>> {
046                        
047                /** The number of iterations (random point draws) to use in each detection cycle.*/
048                @DialogLabel("Number of random draws") 
049                public int randomPointDraws = 1000;
050                
051                /** The maximum distance of any point from the curve to be considered an "inlier".*/
052                @DialogLabel("Max. inlier distance") 
053                public double maxInlierDistance = 2.0;
054                
055                /** The minimum number of inliers required for successful detection.*/
056                @DialogLabel("Min. inlier count") 
057                public int minInlierCount = 100;
058                
059                /** Set true to remove inlier points after each detection.*/
060                @DialogLabel("Remove inliers") 
061                public boolean removeInliers = true;
062                
063                /** Random seed used initialization (0 = no seed).*/
064                @DialogLabel("Random seed (0 = no seed)") 
065                public int randomSeed = 0;
066        }
067        
068        // -----------------------------------------------------------
069        
070        private final RansacParameters params;
071        private final int K;                                            // number of points to draw
072        private final RandomDraw<Pnt2d> randomDraw;     // 
073        
074        RansacDetector(int K, RansacParameters params) {
075                this.K = K;
076                this.params = params;
077                Random rand = (params.randomSeed == 0) ? null : new Random(params.randomSeed);
078                this.randomDraw = new RandomDraw<>(rand);
079        }
080        
081        // ----------------------------------------------------------
082
083        /**
084         * Performs iterative RANSAC steps on the supplied image, which is assumed to be binary (all nonzero pixels are
085         * considered input points). Extracts the point set from the image and calls {@link #detectAll(Pnt2d[], int)}.
086         *
087         * @param bp a binary image (nonzero pixels are considered points)
088         * @param maxCount the maximum number of primitives to detect
089         * @return the list of detected primitives
090         */
091        public List<RansacResult<T>> detectAll(ByteProcessor bp, int maxCount) {
092                Pnt2d[] points = IjUtils.collectNonzeroPoints(bp);
093                if (points.length == 0) {
094                        throw new IllegalArgumentException("empty point set");
095                }
096                return detectAll(points, maxCount);
097        }
098
099        /**
100         * Performs iterative RANSAC steps on the supplied point set until either no more primitive was detected or the
101         * maximum number of primitives was reached. Iteratively calls {@link #detectNext(Pnt2d[])} on the specified point
102         * set.
103         *
104         * @param points the original point set
105         * @param maxCount the maximum number of primitives to detect
106         * @return the list of detected primitives
107         */
108        public List<RansacResult<T>> detectAll(Pnt2d[] points, int maxCount) {
109                List<RansacResult<T>> primitives = new ArrayList<>();
110                int cnt = 0;
111                
112                RansacResult<T> sol = detectNext(points);
113                while (sol != null && cnt < maxCount) {
114                        primitives.add(sol);
115                        cnt = cnt + 1;
116                        sol = detectNext(points);
117                }
118                return primitives;
119        }
120
121        /**
122         * Performs a single RANSAC step on the supplied point set. If {@link RansacParameters#removeInliers} is set true,
123         * all associated inlier points are removed from the point set (by setting array elements to {@code null}).
124         *
125         * @param points an array of {@link Pnt2d} instances (modified)
126         * @return the detected primitive (of generic type T) or {@code null} if unsuccessful
127         */
128        public RansacResult<T> detectNext(Pnt2d[] points) {
129                Pnt2d[] drawInit = null;
130                double scoreInit = -1;
131                T primitiveInit = null;
132                
133                for (int i = 0; i < params.randomPointDraws; i++) {
134                        Pnt2d[] draw = drawRandomPoints(points);
135                        T primitive = fitInitial(draw);
136                        if (primitive == null) {
137                                continue;
138                        }
139                        double score = countInliers(primitive, points);
140                        if (score >= params.minInlierCount && score > scoreInit) {
141                                scoreInit = score;
142                                drawInit = draw;
143                                primitiveInit = primitive;
144                        }
145                }
146                
147                if (primitiveInit == null) {
148                        return null;
149                }
150                else {
151                        // refit the primitive to all inliers:
152                        Pnt2d[] inliers = collectInliers(primitiveInit, points);
153                        T primitiveFinal = fitFinal(inliers);   
154                        if (primitiveFinal != null)
155                                return new RansacResult<T>(drawInit, primitiveInit, primitiveFinal, scoreInit, inliers);
156                        else
157                                throw new RuntimeException("final fit failed!");
158                }
159        }
160
161        /**
162         * Randomly selects {@link #K} unique points from the supplied {@link Pnt2d} array. Inheriting classes may override
163         * this method to enforce specific constraints on the selected points (e.g., see {@link RansacLineDetector}).
164         *
165         * @param points an array of {@link Pnt2d} instances
166         * @return an array of {@link #K} unique points
167         */
168        Pnt2d[] drawRandomPoints(Pnt2d[] points) {      
169                return randomDraw.drawFrom(points, K);
170        }
171        
172        private int countInliers(T curve, Pnt2d[] points) {
173                int count = 0;
174                for (Pnt2d p : points) {
175                        if (p != null) {
176                                double d = curve.getDistance(p);
177                                if (d < params.maxInlierDistance) {
178                                        count++;
179                                }
180                        }
181                }
182                return count;
183        }
184
185        /**
186         * Find all points that are considered inliers with respect to the specified curve and the value of
187         * {@link RansacParameters#maxInlierDistance}. If {@link RansacParameters#removeInliers} is set true, these points
188         * are also removed from the original point set, otherwise they remain.
189         *
190         * @param curve
191         * @param points
192         * @return
193         */
194        private Pnt2d[] collectInliers(Primitive2d curve, Pnt2d[] points) {
195                List<Pnt2d> pList = new ArrayList<>();
196                for (int i = 0; i < points.length; i++) {
197                        Pnt2d p = points[i];
198                        if (p != null) {
199                                double d = curve.getDistance(p);
200                                if (d < params.maxInlierDistance) {
201                                        pList.add(p);
202                                        if (params.removeInliers) {
203                                                points[i] = null;
204                                        }
205                                }
206                        }
207                }
208                return pList.toArray(new Pnt2d[0]);
209        }
210        
211        // abstract methods to be implemented by specific sub-classes: -----------------------
212
213        /**
214         * Fits an initial primitive to the specified points. This abstract method must be implemented by inheriting
215         * classes, which must also specify the required number of initial points ({@link #K}).
216         *
217         * @param draw an array of exactly {@link #K} points
218         * @return a new primitive of type T
219         */
220        abstract T fitInitial(Pnt2d[] draw);
221
222        /**
223         * Fits a primitive to the specified points. This abstract method must be implemented by inheriting classes.
224         *
225         * @param inliers an array of at least {@link #K} points
226         * @return a new primitive of type T.
227         */
228        abstract T fitFinal(Pnt2d[] inliers);
229        
230}