Fix potential invalid select timeouts.
[jagi.git] / src / jagi / event / Driver.java
1 package jagi.event;
2
3 import java.util.*;
4 import java.util.logging.*;
5 import java.util.concurrent.*;
6 import java.io.*;
7 import java.nio.*;
8 import java.nio.channels.*;
9 import java.nio.channels.spi.*;
10
11 public class Driver {
12     private static final Logger log = Logger.getLogger("jagi.event");
13     private static final Logger hlog = Logger.getLogger("jagi.event.handler");
14     private static final ThreadLocal<Driver> current = new ThreadLocal<>();
15     private final Map<SelectorProvider, SelectPool> selectors = new HashMap<>();
16     private final ExecutorService worker = new ThreadPoolExecutor(0, Runtime.getRuntime().availableProcessors(),
17                                                                   5, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(128),
18                                                                   this::thread);
19
20     protected Thread thread(Runnable tgt) {
21         return(new Thread(tgt));
22     }
23
24     protected void handle(Watcher w, int evs) {
25         try {
26             current.set(this);
27             w.handle(evs);
28         } catch(Throwable t) {
29             error(w, t, "handling event");
30         } finally {
31             current.remove();
32         }
33     }
34
35     protected void close(Watcher w) {
36         try {
37             current.set(this);
38             w.close();
39         } catch(Throwable t) {
40             error(w, t, "closing");
41         } finally {
42             current.remove();
43         }
44     }
45
46     protected void submit(Runnable task) {
47         worker.submit(task);
48     }
49
50     protected void error(Watcher w, Throwable t, String thing) {
51         hlog.log(Level.WARNING, w + ": uncaught error when " + thing, t);
52         remove(w);
53     }
54
55     class SelectPool implements Runnable {
56         final SelectorProvider provider;
57         final Selector poll;
58         final Map<Watcher, SelectionKey> watching = new IdentityHashMap<>();
59         final Heap<Watcher, Double> timeheap = new Heap<>(Comparator.naturalOrder());
60         final Map<Watcher, Object> paused = new IdentityHashMap<>();
61
62         SelectPool(SelectorProvider provider) {
63             this.provider = provider;
64             try {
65                 this.poll = provider.openSelector();
66             } catch(IOException e) {
67                 /* I think this counts more as an assertion error. */
68                 throw(new RuntimeException(e));
69             }
70         }
71
72         void handle(Watcher w, int evs) {
73             if(!watching.containsKey(w))
74                 return;
75             try {
76                 pause(w);
77                 submit(() -> {
78                         try {
79                             Driver.this.handle(w, evs);
80                         } finally {
81                             resume(w);
82                         }
83                     });
84             } catch(Throwable t) {
85                 try {
86                     synchronized(selectors) {
87                         remove(w);
88                     }
89                 } catch(Exception e) {
90                     t.addSuppressed(e);
91                 }
92                 log.log(Level.SEVERE, "unexpected error when submitting event", t);
93             }
94         }
95
96         void start() {
97             thread(this).start();
98         }
99
100         public void run() {
101             boolean quit = false;
102             Throwable error = null;
103             try {
104                 double now = time();
105                 while(true) {
106                     long timeout = 0;
107                     synchronized(selectors) {
108                         Double first = timeheap.keypeek();
109                         if((first == null) && watching.isEmpty()) {
110                             quit = true;
111                             selectors.remove(provider);
112                             return;
113                         }
114                         if(first != null)
115                             timeout = Math.max((long)Math.ceil((first - now) * 1000), 0);
116                     }
117                     poll.selectedKeys().clear();
118                     try {
119                         poll.select(timeout);
120                     } catch(IOException e) {
121                         throw(new RuntimeException(e));
122                     }
123                     for(SelectionKey key : poll.selectedKeys())
124                         handle((Watcher)key.attachment(), key.readyOps());
125                     now = time();
126                     while(true) {
127                         Double first = timeheap.keypeek();
128                         if((first == null) || (first > now))
129                             break;
130                         handle(timeheap.remove(), 0);
131                     }
132                 }
133             } catch(Throwable t) {
134                 error = t;
135                 throw(t);
136             } finally {
137                 if(!quit)
138                     log.log(Level.SEVERE, "selector exited abnormally", error);
139             }
140         }
141
142         void pause(Watcher w) {
143             if(paused.containsKey(w))
144                 throw(new IllegalStateException(w + ": already paused"));
145             SelectionKey wc = watching.get(w);
146             Object tc = timeheap.remove(w);
147             if((wc == null) && (tc == null))
148                 throw(new IllegalStateException(w + ": not registered"));
149             if(wc != null)
150                 wc.interestOps(0);
151             paused.put(w, this);
152         }
153
154         void resume(Watcher w) {
155             if(paused.remove(w) == null)
156                 return;
157             SelectionKey wc = watching.get(w);
158             int evs = w.events();
159             double timeout = w.timeout();
160             boolean hastime = timeout < Double.POSITIVE_INFINITY;
161             if(evs < 0) {
162                 remove(w);
163                 return;
164             }
165             wc.interestOps(evs);
166             if(hastime)
167                 timeheap.add(w, timeout);
168             poll.wakeup();
169         }
170
171         void add(Watcher w, SelectableChannel ch) {
172             if(watching.containsKey(w) || paused.containsKey(w) || timeheap.contains(w))
173                 throw(new IllegalStateException(w + ": already registered"));
174             int evs = w.events();
175             double timeout = w.timeout();
176             boolean hastime = timeout < Double.POSITIVE_INFINITY;
177             if(evs < 0) {
178                 submit(() -> close(w));
179                 return;
180             }
181             w.added(Driver.this);
182             try {
183                 watching.put(w, ch.register(poll, evs, w));
184             } catch(ClosedChannelException e) {
185                 throw(new RuntimeException("attempted to watch closed channel", e));
186             }
187             if(hastime)
188                 timeheap.add(w, timeout);
189             poll.wakeup();
190         }
191
192         void remove(Watcher w) {
193             SelectionKey wc = watching.remove(w);
194             Object tc = timeheap.remove(w);
195             Object pc = paused.remove(w);
196             if(wc != null)
197                 wc.cancel();
198             if(((wc != null) || (tc != null)) && (pc != null))
199                 throw(new RuntimeException(w + ": inconsistent internal state"));
200             if(wc == null)
201                 throw(new IllegalStateException(w + ": not registered"));
202             submit(() -> close(w));
203             poll.wakeup();
204         }
205
206         void update(Watcher w) {
207             SelectionKey wc = watching.get(w);
208             if(wc == null)
209                 throw(new IllegalStateException(w + ": not registered"));
210             int evs = w.events();
211             double timeout = w.timeout();
212             boolean hastime = timeout < Double.POSITIVE_INFINITY;
213             if(evs < 0) {
214                 remove(w);
215                 return;
216             }
217             wc.interestOps(evs);
218             if(hastime)
219                 timeheap.set(w, timeout);
220             else
221                 timeheap.remove(w);
222             poll.wakeup();
223         }
224     }
225
226     private SelectPool pool(SelectorProvider provider) {
227         SelectPool pool = selectors.get(provider);
228         if(pool == null) {
229             pool = new SelectPool(provider);
230             selectors.put(provider, pool);
231             pool.start();
232         }
233         return(pool);
234     }
235
236     public void add(Watcher w) {
237         SelectableChannel ch = w.channel();
238         synchronized(selectors) {
239             pool(ch.provider()).add(w, ch);
240         }
241     }
242
243     public void remove(Watcher w) {
244         SelectableChannel ch = w.channel();
245         synchronized(selectors) {
246             pool(ch.provider()).remove(w);
247         }
248     }
249
250     public void update(Watcher w) {
251         SelectableChannel ch = w.channel();
252         synchronized(selectors) {
253             pool(ch.provider()).update(w);
254         }
255     }
256
257     public double time() {
258         return(rtime());
259     }
260
261     private static final long rtimeoff = System.nanoTime();
262     public static double rtime() {
263         return((System.nanoTime() - rtimeoff) / 1e9);
264     }
265
266     private static Driver global = null;
267     public static Driver get() {
268         if(global == null) {
269             synchronized(Driver.class) {
270                 if(global == null)
271                     global = new Driver();
272             }
273         }
274         return(global);
275     }
276
277     public static Driver current() {
278         Driver ret = current.get();
279         if(ret == null)
280             throw(new IllegalStateException("no current driver"));
281         return(ret);
282     }
283 }