202b8ddad9f109b70bd9dff9e714fa5a292dcd25
[jagi.git] / src / jagi / scgi / EventServer.java
1 package jagi.scgi;
2
3 import jagi.*;
4 import jagi.event.*;
5 import java.util.*;
6 import java.util.function.*;
7 import java.util.concurrent.*;
8 import java.util.logging.*;
9 import java.io.*;
10 import java.nio.*;
11 import java.nio.channels.*;
12
13 public class EventServer implements Runnable {
14     private static final double timeout = 5;
15     private static final Logger log = Logger.getLogger("jagi.server");
16     private final ServerSocketChannel sk;
17     private final Function handler;
18     private final Driver ev = Driver.get();
19     private final ExecutorService handlers = new ThreadPoolExecutor(0, Runtime.getRuntime().availableProcessors() * 2,
20                                                                     5, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(64),
21                                                                     tgt -> new Thread(tgt, "Request handler thread"));
22
23     public EventServer(ServerSocketChannel sk, Function handler) {
24         this.sk = sk;
25         this.handler = handler;
26     }
27
28     public static class Request {
29         public final Map<Object, Object> env;
30         public final SocketChannel sk;
31
32         public Request(Map<Object, Object> env, SocketChannel sk) {
33             this.env = env;
34             this.sk = sk;
35         }
36
37         public void close() {
38             ArrayList<Object> cleanup = new ArrayList<>((Collection<?>)env.get("jagi.cleanup"));
39             cleanup.add(sk);
40             RuntimeException ce = null;
41             for(Object obj : cleanup) {
42                 if(obj instanceof AutoCloseable) {
43                     try {
44                         ((AutoCloseable)obj).close();
45                     } catch(Exception e) {
46                         if(ce == null)
47                             ce = new RuntimeException("error(s) occurred during cleanup");
48                         ce.addSuppressed(e);
49                     }
50                 }
51             }
52             if(ce != null)
53                 throw(ce);
54         }
55     }
56
57     protected void error(Request req, Throwable error) {
58         log.log(Level.WARNING, "uncaught exception while handling request", error);
59     }
60
61     public static abstract class ChainWatcher implements Watcher {
62         private Runnable then;
63         public ChainWatcher then(Runnable then) {this.then = then; return(this);}
64
65         public void close() {
66             if(then != null)
67                 then.run();
68         }
69     }
70
71     public static class BufferedOutput extends ChainWatcher {
72         public final SocketChannel sk;
73         public final ByteBuffer buf;
74         private double lastwrite;
75
76         public BufferedOutput(SocketChannel sk, ByteBuffer buf) {
77             this.sk = sk;
78             this.buf = buf;
79         }
80
81         public void added(Driver d) {lastwrite = d.time();}
82         public SelectableChannel channel() {return(sk);}
83         public int events() {return((buf.remaining() > 0) ? SelectionKey.OP_WRITE : -1);}
84         public double timeout() {return(lastwrite + timeout);}
85
86         public void handle(int events) throws IOException {
87             double now = Driver.current().time();
88             if((events & SelectionKey.OP_WRITE) != 0) {
89                 if(sk.write(buf) > 0)
90                     lastwrite = now;
91             }
92             if(now > lastwrite + timeout)
93                 buf.position(buf.limit());
94         }
95     }
96
97     public static class TransferOutput extends ChainWatcher {
98         public final SocketChannel sk;
99         public final ReadableByteChannel in;
100         private final ByteBuffer buf;
101         private boolean eof = false;
102         private double lastwrite;
103
104         public TransferOutput(SocketChannel sk, ReadableByteChannel in) {
105             this.sk = sk;
106             this.in = in;
107             buf = ByteBuffer.allocate(65536);
108             buf.flip();
109         }
110
111         public void added(Driver d) {lastwrite = d.time();}
112         public SelectableChannel channel() {return(sk);}
113         public int events() {return((eof && (buf.remaining() == 0)) ? -1 : SelectionKey.OP_WRITE);}
114         public double timeout() {return(lastwrite + timeout);}
115
116         public void handle(int events) throws IOException {
117             if(!eof && (buf.remaining() == 0)) {
118                 buf.rewind();
119                 while(buf.remaining() > 0) {
120                     if(in.read(buf) < 0)
121                         break;
122                 }
123             }
124             double now = Driver.current().time();
125             if((events & SelectionKey.OP_WRITE) != 0) {
126                 if(sk.write(buf) > 0)
127                     lastwrite = now;
128             }
129             if(now > lastwrite + timeout) {
130                 eof = true;
131                 buf.position(buf.limit());
132             }
133         }
134
135         public void close() {
136             try {
137                 in.close();
138             } catch(IOException e) {
139                 log.log(Level.WARNING, "failed to close transfer channel: " + in, e);
140             } finally {
141                 super.close();
142             }
143         }
144     }
145
146     public static class TransferInput extends ChainWatcher {
147         public final SocketChannel sk;
148         public final WritableByteChannel out;
149         private final ByteBuffer buf;
150         private final long max;
151         private boolean eof = false;
152         private double lastread;
153         private long cur = 0;
154
155         public TransferInput(SocketChannel sk, WritableByteChannel out, long max) {
156             this.sk = sk;
157             this.out = out;
158             this.max = max;
159             buf = ByteBuffer.allocate(65536);
160             buf.flip();
161         }
162
163         public void added(Driver d) {lastread = d.time();}
164         public SelectableChannel channel() {return(sk);}
165         public int events() {return(eof ? -1 : SelectionKey.OP_READ);}
166         public double timeout() {return(lastread + timeout);}
167
168         public void handle(int events) throws IOException {
169             double now = Driver.current().time();
170             if((events & SelectionKey.OP_READ) != 0) {
171                 buf.rewind();
172                 if(buf.remaining() > max - cur)
173                     buf.limit(buf.position() + (int)Math.min(max - cur, Integer.MAX_VALUE));
174                 int rv = sk.read(buf);
175                 if(rv < 0) {
176                     eof = true;
177                 } else if(rv > 0) {
178                     lastread = now;
179                     cur += rv;
180                 }
181                 buf.flip();
182                 while(buf.remaining() > 0)
183                     out.write(buf);
184             }
185             if(now > lastread + timeout) {
186                 eof = true;
187                 buf.position(buf.limit());
188             }
189         }
190
191         public void close() {
192             try {
193                 out.close();
194             } catch(IOException e) {
195                 log.log(Level.WARNING, "failed to close transfer channel: " + out, e);
196             } finally {
197                 super.close();
198             }
199         }
200     }
201
202     protected void respond(Request req, String status, Map resp) {
203         Object output = resp.get("jagi.output");
204         ByteArrayOutputStream buf = new ByteArrayOutputStream();
205         try {
206             Writer head = new OutputStreamWriter(buf, Utils.UTF8);
207             head.write("Status: ");
208             head.write(status);
209             head.write("\n");
210             for(Iterator it = resp.entrySet().iterator(); it.hasNext();) {
211                 Map.Entry ent = (Map.Entry)it.next();
212                 Object val = ent.getValue();
213                 if((ent.getKey() instanceof String) && (val != null)) {
214                     String key = (String)ent.getKey();
215                     if(key.startsWith("http.")) {
216                         String nm = key.substring(5);
217                         if(nm.equalsIgnoreCase("status"))
218                             continue;
219                         if(val instanceof Collection) {
220                             for(Object part : (Collection)val) {
221                                 head.write(nm);
222                                 head.write(": ");
223                                 head.write(part.toString());
224                                 head.write("\n");
225                             }
226                         } else {
227                             head.write(nm);
228                             head.write(": ");
229                             head.write(val.toString());
230                             head.write("\n");
231                         }
232                     }
233                 }
234             }
235             head.write("\n");
236             head.flush();
237         } catch(IOException e) {
238             throw(new RuntimeException("cannot happen"));
239         }
240         ChainWatcher out;
241         if(output == null) {
242             out = new BufferedOutput(req.sk, ByteBuffer.allocate(0));
243         } else if(output instanceof byte[]) {
244             out = new BufferedOutput(req.sk, ByteBuffer.wrap((byte[])output));
245         } else if(output instanceof ByteBuffer) {
246             out = new BufferedOutput(req.sk, (ByteBuffer)output);
247         } else if(output instanceof String) {
248             out = new BufferedOutput(req.sk, ByteBuffer.wrap(((String)output).getBytes(Utils.UTF8)));
249         } else if(output instanceof CharSequence) {
250             out = new BufferedOutput(req.sk, Utils.UTF8.encode(CharBuffer.wrap((CharSequence)output)));
251         } else if(output instanceof InputStream) {
252             out = new TransferOutput(req.sk, Channels.newChannel((InputStream)output));
253         } else if(output instanceof ReadableByteChannel) {
254             out = new TransferOutput(req.sk, (ReadableByteChannel)output);
255         } else {
256             throw(new IllegalArgumentException("response-body: " + output));
257         }
258         out.then(() -> submit(req::close));
259         ev.add(new BufferedOutput(req.sk, ByteBuffer.wrap(buf.toByteArray())).then(() -> ev.add(out)));
260     }
261
262     @SuppressWarnings("unchecked")
263     protected void handle(Request req, Function handler) {
264         boolean handoff = false;
265         try {
266             Throwable error = null;
267             try {
268                 Map resp = (Map)handler.apply(req.env);
269                 String st;
270                 if((st = (String)resp.get("jagi.status")) != null) {
271                     Function next = (Function)resp.get("jagi.next");
272                     switch(st) {
273                     case "feed-input":
274                         Object sink = resp.get("jagi.input-sink");
275                         Object clen = req.env.get("HTTP_CONTENT_LENGTH");
276                         long max = 0;
277                         if(clen instanceof String) {
278                             try {
279                                 max = Long.parseLong((String)clen);
280                             } catch(NumberFormatException e) {
281                             }
282                         }
283                         if(sink instanceof WritableByteChannel) {
284                             ev.add(new TransferInput(req.sk, (WritableByteChannel)sink, max).then(() -> submit(() -> handle(req, next))));
285                         } else if(sink instanceof OutputStream) {
286                             ev.add(new TransferInput(req.sk, Channels.newChannel((OutputStream)sink), max).then(() -> submit(() -> handle(req, next))));
287                         } else {
288                             throw(new IllegalArgumentException("input-sink: " + sink));
289                         }
290                         handoff = true;
291                         break;
292                     case "chain":
293                         submit(() -> handle(req, next));
294                         handoff = true;
295                         break;
296                     default:
297                         throw(new IllegalArgumentException("jagi.status: " + st));
298                     }
299                 } else if((st = (String)resp.get("http.status")) != null) {
300                     respond(req, st, resp);
301                     handoff = true;
302                 } else {
303                     throw(new IllegalArgumentException("neither http.status nor jagi.status set"));
304                 }
305             } catch(Throwable t) {
306                 error = t;
307                 throw(t);
308             } finally {
309                 if(!handoff) {
310                     try {
311                         req.close();
312                     } catch(Throwable ce) {
313                         if(error == null) {
314                             throw(ce);
315                         } else {
316                             error.addSuppressed(ce);
317                         }
318                     }
319                 }
320             }
321         } catch(Throwable t) {
322             error(req, t);
323         }
324     }
325
326     protected void submit(Runnable task) {
327         handlers.submit(task);
328     }
329
330     class Client implements Watcher {
331         final SocketChannel sk;
332         double lastread;
333         boolean eof = false, handoff = false;
334         int headlen = 0;
335         ByteBuffer head = null;
336         Map<Object, Object> env = null;
337         Request req = null;
338
339         Client(SocketChannel sk) {
340             this.sk = sk;
341         }
342
343         public void added(Driver d) {lastread = d.time();}
344         public SelectableChannel channel() {return(sk);}
345         public double timeout() {return(lastread + timeout);}
346         public int events() {
347             if(eof)
348                 return(-1);
349             if(env == null)
350                 return(SelectionKey.OP_READ);
351             return(-1);
352         }
353
354         boolean readhead() throws IOException {
355             if(head == null) {
356                 ByteBuffer buf = ByteBuffer.allocate(1);
357                 while(true) {
358                     buf.rewind();
359                     int rv = sk.read(buf);
360                     if(rv < 0) {
361                         eof = true;
362                         return(false);
363                     } else if(rv == 0) {
364                         return(false);
365                     } else {
366                         lastread = Driver.current().time();
367                         int c = buf.get(0);
368                         if((c >= '0') && (c <= '9')) {
369                             headlen = (headlen * 10) + (c - '0');
370                         } else if(c == ':') {
371                             head = ByteBuffer.allocate(headlen + 1);
372                             break;
373                         } else {
374                             eof = true;
375                             return(false);
376                         }
377                     }
378                 }
379             }
380             while(true) {
381                 if(head.remaining() == 0) {
382                     if(head.get(head.limit() - 1) != ',') {
383                         /* Unterminated netstring */
384                         eof = true;
385                         return(false);
386                     }
387                     head.limit(head.limit() - 1);
388                     env = Jagi.mkenv(Scgi.splithead(head), sk);
389                     return(true);
390                 }
391                 int rv = sk.read(head);
392                 if(rv < 0) {
393                     eof = true;
394                     return(false);
395                 } else if(rv == 0) {
396                     return(false);
397                 }
398             }
399         }
400
401         public void handle(int events) throws IOException {
402             if((events & SelectionKey.OP_READ) != 0) {
403                 if((env == null) && !readhead())
404                     return;
405                 req = new Request(env, sk);
406                 handoff = true;
407             }
408             if(Driver.current().time() > (lastread + timeout))
409                 eof = true;
410         }
411
412         public void close() {
413             if(req != null)
414                 submit(() -> EventServer.this.handle(req, handler));
415             if(!handoff) {
416                 try {
417                     sk.close();
418                 } catch(IOException e) {
419                 }
420             }
421         }
422     }
423
424     class Accepter implements Watcher {
425         boolean closed = false;
426
427         public SelectableChannel channel() {return(sk);}
428         public int events() {return(SelectionKey.OP_ACCEPT);}
429
430         public void handle(int events) throws IOException {
431             if((events & SelectionKey.OP_ACCEPT) != 0)
432                 Driver.current().add(new Client(sk.accept()));
433         }
434
435         public void close() {
436             synchronized(this) {
437                 closed = true;
438                 notifyAll();
439             }
440         }
441     }
442
443     public void run() {
444         Accepter main = new Accepter();
445         ev.add(main);
446         try {
447             synchronized(main) {
448                 while(!main.closed) {
449                     main.wait();
450                 }
451             }
452         } catch(InterruptedException e) {
453             ev.remove(main);
454         } finally {
455             try {
456                 sk.close();
457             } catch(IOException e) {
458                 throw(new RuntimeException(e));
459             }
460         }
461     }
462 }