Fixed some event-server channel-transfer bugs.
[jagi.git] / src / jagi / scgi / EventServer.java
1 package jagi.scgi;
2
3 import jagi.*;
4 import jagi.event.*;
5 import java.util.*;
6 import java.util.function.*;
7 import java.util.concurrent.*;
8 import java.util.logging.*;
9 import java.io.*;
10 import java.nio.*;
11 import java.nio.channels.*;
12
13 public class EventServer implements Runnable {
14     private static final double timeout = 5;
15     private static final Logger log = Logger.getLogger("jagi.server");
16     private final ServerSocketChannel sk;
17     private final Function handler;
18     private final Driver ev = Driver.get();
19     private final ExecutorService handlers = new ThreadPoolExecutor(0, Runtime.getRuntime().availableProcessors() * 2,
20                                                                     5, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(64),
21                                                                     tgt -> new Thread(tgt, "Request handler thread"));
22
23     public EventServer(ServerSocketChannel sk, Function handler) {
24         this.sk = sk;
25         this.handler = handler;
26     }
27
28     public static class Request {
29         public final Map<Object, Object> env;
30         public final SocketChannel sk;
31
32         public Request(Map<Object, Object> env, SocketChannel sk) {
33             this.env = env;
34             this.sk = sk;
35         }
36
37         public void close() {
38             ArrayList<Object> cleanup = new ArrayList<>((Collection<?>)env.get("jagi.cleanup"));
39             cleanup.add(sk);
40             RuntimeException ce = null;
41             for(Object obj : cleanup) {
42                 if(obj instanceof AutoCloseable) {
43                     try {
44                         ((AutoCloseable)obj).close();
45                     } catch(Exception e) {
46                         if(ce == null)
47                             ce = new RuntimeException("error(s) occurred during cleanup");
48                         ce.addSuppressed(e);
49                     }
50                 }
51             }
52             if(ce != null)
53                 throw(ce);
54         }
55     }
56
57     protected void error(Request req, Throwable error) {
58         log.log(Level.WARNING, "uncaught exception while handling request", error);
59     }
60
61     public static abstract class ChainWatcher implements Watcher {
62         private Runnable then;
63         public ChainWatcher then(Runnable then) {this.then = then; return(this);}
64
65         public void close() {
66             if(then != null)
67                 then.run();
68         }
69     }
70
71     public static class BufferedOutput extends ChainWatcher {
72         public final SocketChannel sk;
73         public final ByteBuffer buf;
74         private double lastwrite;
75
76         public BufferedOutput(SocketChannel sk, ByteBuffer buf) {
77             this.sk = sk;
78             this.buf = buf;
79         }
80
81         public void added(Driver d) {lastwrite = d.time();}
82         public SelectableChannel channel() {return(sk);}
83         public int events() {return((buf.remaining() > 0) ? SelectionKey.OP_WRITE : -1);}
84         public double timeout() {return(lastwrite + timeout);}
85
86         public void handle(int events) throws IOException {
87             double now = Driver.current().time();
88             if((events & SelectionKey.OP_WRITE) != 0) {
89                 if(sk.write(buf) > 0)
90                     lastwrite = now;
91             }
92             if(now > lastwrite + timeout)
93                 buf.position(buf.limit());
94         }
95     }
96
97     public static class TransferOutput extends ChainWatcher {
98         public final SocketChannel sk;
99         public final ReadableByteChannel in;
100         private final ByteBuffer buf;
101         private boolean eof = false;
102         private double lastwrite;
103
104         public TransferOutput(SocketChannel sk, ReadableByteChannel in) {
105             this.sk = sk;
106             this.in = in;
107             buf = ByteBuffer.allocate(65536);
108             buf.flip();
109         }
110
111         public void added(Driver d) {lastwrite = d.time();}
112         public SelectableChannel channel() {return(sk);}
113         public int events() {return((eof && (buf.remaining() == 0)) ? -1 : SelectionKey.OP_WRITE);}
114         public double timeout() {return(lastwrite + timeout);}
115
116         public void handle(int events) throws IOException {
117             if(!eof && (buf.remaining() == 0)) {
118                 buf.clear();
119                 while(buf.remaining() > 0) {
120                     if(in.read(buf) < 0)
121                         break;
122                 }
123             }
124             double now = Driver.current().time();
125             if((events & SelectionKey.OP_WRITE) != 0) {
126                 if(sk.write(buf) > 0)
127                     lastwrite = now;
128             }
129             if(now > lastwrite + timeout) {
130                 eof = true;
131                 buf.position(buf.limit());
132             }
133         }
134
135         public void close() {
136             try {
137                 in.close();
138             } catch(IOException e) {
139                 log.log(Level.WARNING, "failed to close transfer channel: " + in, e);
140             } finally {
141                 super.close();
142             }
143         }
144     }
145
146     public static class TransferInput extends ChainWatcher {
147         public final SocketChannel sk;
148         public final WritableByteChannel out;
149         private final ByteBuffer buf;
150         private final long max;
151         private boolean eof = false;
152         private double lastread;
153         private long cur = 0;
154
155         public TransferInput(SocketChannel sk, WritableByteChannel out, long max) {
156             this.sk = sk;
157             this.out = out;
158             this.max = max;
159             buf = ByteBuffer.allocate(65536);
160             buf.flip();
161         }
162
163         public void added(Driver d) {lastread = d.time();}
164         public SelectableChannel channel() {return(sk);}
165         public int events() {return(eof ? -1 : SelectionKey.OP_READ);}
166         public double timeout() {return(lastread + timeout);}
167
168         public void handle(int events) throws IOException {
169             double now = Driver.current().time();
170             if((events & SelectionKey.OP_READ) != 0) {
171                 buf.clear();
172                 if(buf.remaining() > max - cur)
173                     buf.limit(buf.position() + (int)Math.min(max - cur, Integer.MAX_VALUE));
174                 int rv = sk.read(buf);
175                 if(rv < 0) {
176                     eof = true;
177                 } else if(rv > 0) {
178                     lastread = now;
179                     if((cur += rv) >= max)
180                         eof = true;
181                 }
182                 buf.flip();
183                 while(buf.remaining() > 0)
184                     out.write(buf);
185             }
186             if(now > lastread + timeout) {
187                 eof = true;
188                 buf.position(buf.limit());
189             }
190         }
191
192         public void close() {
193             try {
194                 out.close();
195             } catch(IOException e) {
196                 log.log(Level.WARNING, "failed to close transfer channel: " + out, e);
197             } finally {
198                 super.close();
199             }
200         }
201     }
202
203     protected void respond(Request req, String status, Map resp) {
204         Object output = resp.get("jagi.output");
205         ByteArrayOutputStream buf = new ByteArrayOutputStream();
206         try {
207             Writer head = new OutputStreamWriter(buf, Utils.UTF8);
208             head.write("Status: ");
209             head.write(status);
210             head.write("\n");
211             for(Iterator it = resp.entrySet().iterator(); it.hasNext();) {
212                 Map.Entry ent = (Map.Entry)it.next();
213                 Object val = ent.getValue();
214                 if((ent.getKey() instanceof String) && (val != null)) {
215                     String key = (String)ent.getKey();
216                     if(key.startsWith("http.")) {
217                         String nm = key.substring(5);
218                         if(nm.equalsIgnoreCase("status"))
219                             continue;
220                         if(val instanceof Collection) {
221                             for(Object part : (Collection)val) {
222                                 head.write(nm);
223                                 head.write(": ");
224                                 head.write(part.toString());
225                                 head.write("\n");
226                             }
227                         } else {
228                             head.write(nm);
229                             head.write(": ");
230                             head.write(val.toString());
231                             head.write("\n");
232                         }
233                     }
234                 }
235             }
236             head.write("\n");
237             head.flush();
238         } catch(IOException e) {
239             throw(new RuntimeException("cannot happen"));
240         }
241         ChainWatcher out;
242         if(output == null) {
243             out = new BufferedOutput(req.sk, ByteBuffer.allocate(0));
244         } else if(output instanceof byte[]) {
245             out = new BufferedOutput(req.sk, ByteBuffer.wrap((byte[])output));
246         } else if(output instanceof ByteBuffer) {
247             out = new BufferedOutput(req.sk, (ByteBuffer)output);
248         } else if(output instanceof String) {
249             out = new BufferedOutput(req.sk, ByteBuffer.wrap(((String)output).getBytes(Utils.UTF8)));
250         } else if(output instanceof CharSequence) {
251             out = new BufferedOutput(req.sk, Utils.UTF8.encode(CharBuffer.wrap((CharSequence)output)));
252         } else if(output instanceof InputStream) {
253             out = new TransferOutput(req.sk, Channels.newChannel((InputStream)output));
254         } else if(output instanceof ReadableByteChannel) {
255             out = new TransferOutput(req.sk, (ReadableByteChannel)output);
256         } else {
257             throw(new IllegalArgumentException("response-body: " + output));
258         }
259         out.then(() -> submit(req::close));
260         ev.add(new BufferedOutput(req.sk, ByteBuffer.wrap(buf.toByteArray())).then(() -> ev.add(out)));
261     }
262
263     @SuppressWarnings("unchecked")
264     protected void handle(Request req, Function handler) {
265         boolean handoff = false;
266         try {
267             Throwable error = null;
268             try {
269                 Map resp = (Map)handler.apply(req.env);
270                 String st;
271                 if((st = (String)resp.get("jagi.status")) != null) {
272                     Function next = (Function)resp.get("jagi.next");
273                     switch(st) {
274                     case "feed-input":
275                         Object sink = resp.get("jagi.input-sink");
276                         Object clen = req.env.get("HTTP_CONTENT_LENGTH");
277                         long max = 0;
278                         if(clen instanceof String) {
279                             try {
280                                 max = Long.parseLong((String)clen);
281                             } catch(NumberFormatException e) {
282                             }
283                         }
284                         if(sink instanceof WritableByteChannel) {
285                             ev.add(new TransferInput(req.sk, (WritableByteChannel)sink, max).then(() -> submit(() -> handle(req, next))));
286                         } else if(sink instanceof OutputStream) {
287                             ev.add(new TransferInput(req.sk, Channels.newChannel((OutputStream)sink), max).then(() -> submit(() -> handle(req, next))));
288                         } else {
289                             throw(new IllegalArgumentException("input-sink: " + sink));
290                         }
291                         handoff = true;
292                         break;
293                     case "chain":
294                         submit(() -> handle(req, next));
295                         handoff = true;
296                         break;
297                     default:
298                         throw(new IllegalArgumentException("jagi.status: " + st));
299                     }
300                 } else if((st = (String)resp.get("http.status")) != null) {
301                     respond(req, st, resp);
302                     handoff = true;
303                 } else {
304                     throw(new IllegalArgumentException("neither http.status nor jagi.status set"));
305                 }
306             } catch(Throwable t) {
307                 error = t;
308                 throw(t);
309             } finally {
310                 if(!handoff) {
311                     try {
312                         req.close();
313                     } catch(Throwable ce) {
314                         if(error == null) {
315                             throw(ce);
316                         } else {
317                             error.addSuppressed(ce);
318                         }
319                     }
320                 }
321             }
322         } catch(Throwable t) {
323             error(req, t);
324         }
325     }
326
327     protected void submit(Runnable task) {
328         handlers.submit(task);
329     }
330
331     class Client implements Watcher {
332         final SocketChannel sk;
333         double lastread;
334         boolean eof = false, handoff = false;
335         int headlen = 0;
336         ByteBuffer head = null;
337         Map<Object, Object> env = null;
338         Request req = null;
339
340         Client(SocketChannel sk) {
341             this.sk = sk;
342         }
343
344         public void added(Driver d) {lastread = d.time();}
345         public SelectableChannel channel() {return(sk);}
346         public double timeout() {return(lastread + timeout);}
347         public int events() {
348             if(eof)
349                 return(-1);
350             if(env == null)
351                 return(SelectionKey.OP_READ);
352             return(-1);
353         }
354
355         boolean readhead() throws IOException {
356             if(head == null) {
357                 ByteBuffer buf = ByteBuffer.allocate(1);
358                 while(true) {
359                     buf.rewind();
360                     int rv = sk.read(buf);
361                     if(rv < 0) {
362                         eof = true;
363                         return(false);
364                     } else if(rv == 0) {
365                         return(false);
366                     } else {
367                         lastread = Driver.current().time();
368                         int c = buf.get(0);
369                         if((c >= '0') && (c <= '9')) {
370                             headlen = (headlen * 10) + (c - '0');
371                         } else if(c == ':') {
372                             head = ByteBuffer.allocate(headlen + 1);
373                             break;
374                         } else {
375                             eof = true;
376                             return(false);
377                         }
378                     }
379                 }
380             }
381             while(true) {
382                 if(head.remaining() == 0) {
383                     if(head.get(head.limit() - 1) != ',') {
384                         /* Unterminated netstring */
385                         eof = true;
386                         return(false);
387                     }
388                     head.limit(head.limit() - 1);
389                     env = Jagi.mkenv(Scgi.splithead(head), sk);
390                     return(true);
391                 }
392                 int rv = sk.read(head);
393                 if(rv < 0) {
394                     eof = true;
395                     return(false);
396                 } else if(rv == 0) {
397                     return(false);
398                 }
399             }
400         }
401
402         public void handle(int events) throws IOException {
403             if((events & SelectionKey.OP_READ) != 0) {
404                 if((env == null) && !readhead())
405                     return;
406                 req = new Request(env, sk);
407                 handoff = true;
408             }
409             if(Driver.current().time() > (lastread + timeout))
410                 eof = true;
411         }
412
413         public void close() {
414             if(req != null)
415                 submit(() -> EventServer.this.handle(req, handler));
416             if(!handoff) {
417                 try {
418                     sk.close();
419                 } catch(IOException e) {
420                 }
421             }
422         }
423     }
424
425     class Accepter implements Watcher {
426         boolean closed = false;
427
428         public SelectableChannel channel() {return(sk);}
429         public int events() {return(SelectionKey.OP_ACCEPT);}
430
431         public void handle(int events) throws IOException {
432             if((events & SelectionKey.OP_ACCEPT) != 0)
433                 Driver.current().add(new Client(sk.accept()));
434         }
435
436         public void close() {
437             synchronized(this) {
438                 closed = true;
439                 notifyAll();
440             }
441         }
442     }
443
444     public void run() {
445         Accepter main = new Accepter();
446         ev.add(main);
447         try {
448             synchronized(main) {
449                 while(!main.closed) {
450                     main.wait();
451                 }
452             }
453         } catch(InterruptedException e) {
454             ev.remove(main);
455         } finally {
456             try {
457                 sk.close();
458             } catch(IOException e) {
459                 throw(new RuntimeException(e));
460             }
461         }
462     }
463 }