1a6bc2a42eb68599c93003639332691680b5c4b1
[jagi.git] / src / jagi / scgi / EventServer.java
1 package jagi.scgi;
2
3 import jagi.*;
4 import jagi.event.*;
5 import java.util.*;
6 import java.util.function.*;
7 import java.util.concurrent.*;
8 import java.util.logging.*;
9 import java.io.*;
10 import java.nio.*;
11 import java.nio.channels.*;
12
13 public class EventServer implements Runnable {
14     private static final double timeout = 5;
15     private static final Logger log = Logger.getLogger("jagi.server");
16     private final ServerSocketChannel sk;
17     private final Function handler;
18     private final Driver ev = Driver.get();
19     private final ExecutorService handlers = new ThreadPoolExecutor(0, Runtime.getRuntime().availableProcessors() * 2,
20                                                                     5, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(64),
21                                                                     tgt -> new Thread(tgt, "Request handler thread"));
22
23     public EventServer(ServerSocketChannel sk, Function handler) {
24         try {
25             sk.configureBlocking(false);
26         } catch(IOException e) {
27             throw(new RuntimeException(e));
28         }
29         this.sk = sk;
30         this.handler = handler;
31     }
32
33     public static class Request {
34         public final Map<Object, Object> env;
35         public final SocketChannel sk;
36
37         public Request(Map<Object, Object> env, SocketChannel sk) {
38             this.env = env;
39             this.sk = sk;
40         }
41
42         public void close() {
43             ArrayList<Object> cleanup = new ArrayList<>((Collection<?>)env.get("jagi.cleanup"));
44             cleanup.add(sk);
45             RuntimeException ce = null;
46             for(Object obj : cleanup) {
47                 if(obj instanceof AutoCloseable) {
48                     try {
49                         ((AutoCloseable)obj).close();
50                     } catch(Exception e) {
51                         if(ce == null)
52                             ce = new RuntimeException("error(s) occurred during cleanup");
53                         ce.addSuppressed(e);
54                     }
55                 }
56             }
57             if(ce != null)
58                 throw(ce);
59         }
60     }
61
62     protected void error(Request req, Throwable error) {
63         log.log(Level.WARNING, "uncaught exception while handling request", error);
64     }
65
66     public static abstract class ChainWatcher implements Watcher {
67         private Runnable then;
68         public ChainWatcher then(Runnable then) {this.then = then; return(this);}
69
70         public void close() {
71             if(then != null)
72                 then.run();
73         }
74     }
75
76     public static class BufferedOutput extends ChainWatcher {
77         public final SocketChannel sk;
78         public final ByteBuffer buf;
79         private double lastwrite;
80
81         public BufferedOutput(SocketChannel sk, ByteBuffer buf) {
82             this.sk = sk;
83             this.buf = buf;
84         }
85
86         public void added(Driver d) {lastwrite = d.time();}
87         public SelectableChannel channel() {return(sk);}
88         public int events() {return((buf.remaining() > 0) ? SelectionKey.OP_WRITE : -1);}
89         public double timeout() {return(lastwrite + timeout);}
90
91         public void handle(int events) throws IOException {
92             double now = Driver.current().time();
93             if((events & SelectionKey.OP_WRITE) != 0) {
94                 if(sk.write(buf) > 0)
95                     lastwrite = now;
96             }
97             if(now > lastwrite + timeout)
98                 buf.position(buf.limit());
99         }
100     }
101
102     public static class TransferOutput extends ChainWatcher {
103         public final SocketChannel sk;
104         public final ReadableByteChannel in;
105         private final ByteBuffer buf;
106         private boolean eof = false;
107         private double lastwrite;
108
109         public TransferOutput(SocketChannel sk, ReadableByteChannel in) {
110             this.sk = sk;
111             this.in = in;
112             buf = ByteBuffer.allocate(65536);
113             buf.flip();
114         }
115
116         public void added(Driver d) {lastwrite = d.time();}
117         public SelectableChannel channel() {return(sk);}
118         public int events() {return((eof && (buf.remaining() == 0)) ? -1 : SelectionKey.OP_WRITE);}
119         public double timeout() {return(lastwrite + timeout);}
120
121         public void handle(int events) throws IOException {
122             if(!eof && (buf.remaining() == 0)) {
123                 buf.rewind();
124                 while(buf.remaining() > 0) {
125                     if(in.read(buf) < 0)
126                         break;
127                 }
128             }
129             double now = Driver.current().time();
130             if((events & SelectionKey.OP_WRITE) != 0) {
131                 if(sk.write(buf) > 0)
132                     lastwrite = now;
133             }
134             if(now > lastwrite + timeout) {
135                 eof = true;
136                 buf.position(buf.limit());
137             }
138         }
139
140         public void close() {
141             try {
142                 in.close();
143             } catch(IOException e) {
144                 log.log(Level.WARNING, "failed to close transfer channel: " + in, e);
145             } finally {
146                 super.close();
147             }
148         }
149     }
150
151     public static class TransferInput extends ChainWatcher {
152         public final SocketChannel sk;
153         public final WritableByteChannel out;
154         private final ByteBuffer buf;
155         private final long max;
156         private boolean eof = false;
157         private double lastread;
158         private long cur = 0;
159
160         public TransferInput(SocketChannel sk, WritableByteChannel out, long max) {
161             this.sk = sk;
162             this.out = out;
163             this.max = max;
164             buf = ByteBuffer.allocate(65536);
165             buf.flip();
166         }
167
168         public void added(Driver d) {lastread = d.time();}
169         public SelectableChannel channel() {return(sk);}
170         public int events() {return(eof ? -1 : SelectionKey.OP_READ);}
171         public double timeout() {return(lastread + timeout);}
172
173         public void handle(int events) throws IOException {
174             double now = Driver.current().time();
175             if((events & SelectionKey.OP_READ) != 0) {
176                 buf.rewind();
177                 if(buf.remaining() > max - cur)
178                     buf.limit(buf.position() + (int)Math.min(max - cur, Integer.MAX_VALUE));
179                 int rv = sk.read(buf);
180                 if(rv < 0) {
181                     eof = true;
182                 } else if(rv > 0) {
183                     lastread = now;
184                     cur += rv;
185                 }
186                 buf.flip();
187                 while(buf.remaining() > 0)
188                     out.write(buf);
189             }
190             if(now > lastread + timeout) {
191                 eof = true;
192                 buf.position(buf.limit());
193             }
194         }
195
196         public void close() {
197             try {
198                 out.close();
199             } catch(IOException e) {
200                 log.log(Level.WARNING, "failed to close transfer channel: " + out, e);
201             } finally {
202                 super.close();
203             }
204         }
205     }
206
207     protected void respond(Request req, String status, Map resp) {
208         Object output = resp.get("jagi.output");
209         ByteArrayOutputStream buf = new ByteArrayOutputStream();
210         try {
211             Writer head = new OutputStreamWriter(buf, Utils.UTF8);
212             head.write("Status: ");
213             head.write(status);
214             head.write("\n");
215             for(Iterator it = resp.entrySet().iterator(); it.hasNext();) {
216                 Map.Entry ent = (Map.Entry)it.next();
217                 Object val = ent.getValue();
218                 if((ent.getKey() instanceof String) && (val != null)) {
219                     String key = (String)ent.getKey();
220                     if(key.startsWith("http.")) {
221                         String nm = key.substring(5);
222                         if(nm.equalsIgnoreCase("status"))
223                             continue;
224                         if(val instanceof Collection) {
225                             for(Object part : (Collection)val) {
226                                 head.write(nm);
227                                 head.write(": ");
228                                 head.write(part.toString());
229                                 head.write("\n");
230                             }
231                         } else {
232                             head.write(nm);
233                             head.write(": ");
234                             head.write(val.toString());
235                             head.write("\n");
236                         }
237                     }
238                 }
239             }
240             head.write("\n");
241             head.flush();
242         } catch(IOException e) {
243             throw(new RuntimeException("cannot happen"));
244         }
245         ChainWatcher out;
246         if(output == null) {
247             out = new BufferedOutput(req.sk, ByteBuffer.allocate(0));
248         } else if(output instanceof byte[]) {
249             out = new BufferedOutput(req.sk, ByteBuffer.wrap((byte[])output));
250         } else if(output instanceof ByteBuffer) {
251             out = new BufferedOutput(req.sk, (ByteBuffer)output);
252         } else if(output instanceof String) {
253             out = new BufferedOutput(req.sk, ByteBuffer.wrap(((String)output).getBytes(Utils.UTF8)));
254         } else if(output instanceof CharSequence) {
255             out = new BufferedOutput(req.sk, Utils.UTF8.encode(CharBuffer.wrap((CharSequence)output)));
256         } else if(output instanceof InputStream) {
257             out = new TransferOutput(req.sk, Channels.newChannel((InputStream)output));
258         } else if(output instanceof ReadableByteChannel) {
259             out = new TransferOutput(req.sk, (ReadableByteChannel)output);
260         } else {
261             throw(new IllegalArgumentException("response-body: " + output));
262         }
263         out.then(() -> submit(req::close));
264         ev.add(new BufferedOutput(req.sk, ByteBuffer.wrap(buf.toByteArray())).then(() -> ev.add(out)));
265     }
266
267     @SuppressWarnings("unchecked")
268     protected void handle(Request req, Function handler) {
269         boolean handoff = false;
270         try {
271             Throwable error = null;
272             try {
273                 Map resp = (Map)handler.apply(req.env);
274                 String st;
275                 if((st = (String)resp.get("jagi.status")) != null) {
276                     Function next = (Function)resp.get("jagi.next");
277                     switch(st) {
278                     case "feed-input":
279                         Object sink = resp.get("jagi.input-sink");
280                         Object clen = req.env.get("HTTP_CONTENT_LENGTH");
281                         long max = 0;
282                         if(clen instanceof String) {
283                             try {
284                                 max = Long.parseLong((String)clen);
285                             } catch(NumberFormatException e) {
286                             }
287                         }
288                         if(sink instanceof WritableByteChannel) {
289                             ev.add(new TransferInput(req.sk, (WritableByteChannel)sink, max).then(() -> submit(() -> handle(req, next))));
290                         } else if(sink instanceof OutputStream) {
291                             ev.add(new TransferInput(req.sk, Channels.newChannel((OutputStream)sink), max).then(() -> submit(() -> handle(req, next))));
292                         } else {
293                             throw(new IllegalArgumentException("input-sink: " + sink));
294                         }
295                         handoff = true;
296                         break;
297                     case "chain":
298                         submit(() -> handle(req, next));
299                         handoff = true;
300                         break;
301                     default:
302                         throw(new IllegalArgumentException("jagi.status: " + st));
303                     }
304                 } else if((st = (String)resp.get("http.status")) != null) {
305                     respond(req, st, resp);
306                     handoff = true;
307                 } else {
308                     throw(new IllegalArgumentException("neither http.status nor jagi.status set"));
309                 }
310             } catch(Throwable t) {
311                 error = t;
312                 throw(t);
313             } finally {
314                 if(!handoff) {
315                     try {
316                         req.close();
317                     } catch(Throwable ce) {
318                         if(error == null) {
319                             throw(ce);
320                         } else {
321                             error.addSuppressed(ce);
322                         }
323                     }
324                 }
325             }
326         } catch(Throwable t) {
327             error(req, t);
328         }
329     }
330
331     protected void submit(Runnable task) {
332         handlers.submit(task);
333     }
334
335     class Client implements Watcher {
336         final SocketChannel sk;
337         double lastread;
338         boolean eof = false, handoff = false;
339         int headlen = 0;
340         ByteBuffer head = null;
341         Map<Object, Object> env = null;
342
343         Client(SocketChannel sk) {
344             this.sk = sk;
345         }
346
347         public void added(Driver d) {lastread = d.time();}
348         public SelectableChannel channel() {return(sk);}
349         public double timeout() {return(lastread + timeout);}
350         public int events() {
351             if(eof)
352                 return(-1);
353             if(env == null)
354                 return(SelectionKey.OP_READ);
355             return(-1);
356         }
357
358         boolean readhead() throws IOException {
359             if(head == null) {
360                 ByteBuffer buf = ByteBuffer.allocate(1);
361                 while(true) {
362                     buf.rewind();
363                     int rv = sk.read(buf);
364                     if(rv < 0) {
365                         eof = true;
366                         return(false);
367                     } else if(rv == 0) {
368                         return(false);
369                     } else {
370                         lastread = Driver.current().time();
371                         int c = buf.get(0);
372                         if((c >= '0') && (c <= '9')) {
373                             headlen = (headlen * 10) + (c - '0');
374                         } else if(c == ':') {
375                             head = ByteBuffer.allocate(headlen + 1);
376                             break;
377                         } else {
378                             eof = true;
379                             return(false);
380                         }
381                     }
382                 }
383             }
384             while(true) {
385                 if(head.remaining() == 0) {
386                     if(head.get(head.limit() - 1) != ',') {
387                         /* Unterminated netstring */
388                         eof = true;
389                         return(false);
390                     }
391                     head.limit(head.limit() - 1);
392                     env = Jagi.mkenv(Scgi.splithead(head), sk);
393                     return(true);
394                 }
395                 int rv = sk.read(head);
396                 if(rv < 0) {
397                     eof = true;
398                     return(false);
399                 } else if(rv == 0) {
400                     return(false);
401                 }
402             }
403         }
404
405         public void handle(int events) throws IOException {
406             if((events & SelectionKey.OP_READ) != 0) {
407                 if((env == null) && !readhead())
408                     return;
409                 Request req = new Request(env, sk);
410                 submit(() -> EventServer.this.handle(req, handler));
411                 handoff = true;
412             }
413             if(Driver.current().time() > (lastread + timeout))
414                 eof = true;
415         }
416
417         public void close() {
418             if(!handoff) {
419                 try {
420                     sk.close();
421                 } catch(IOException e) {
422                 }
423             }
424         }
425     }
426
427     class Accepter implements Watcher {
428         boolean closed = false;
429
430         public SelectableChannel channel() {return(sk);}
431         public int events() {return(SelectionKey.OP_ACCEPT);}
432
433         public void handle(int events) throws IOException {
434             if((events & SelectionKey.OP_ACCEPT) != 0) {
435                 SocketChannel cl = sk.accept();
436                 cl.configureBlocking(false);
437                 Driver.current().add(new Client(cl));
438             }
439         }
440
441         public void close() {
442             synchronized(this) {
443                 closed = true;
444                 notifyAll();
445             }
446         }
447     }
448
449     public void run() {
450         Accepter main = new Accepter();
451         ev.add(main);
452         try {
453             synchronized(main) {
454                 while(!main.closed) {
455                     main.wait();
456                 }
457             }
458         } catch(InterruptedException e) {
459             ev.remove(main);
460         } finally {
461             try {
462                 sk.close();
463             } catch(IOException e) {
464                 throw(new RuntimeException(e));
465             }
466         }
467     }
468 }