Fixed some TransferOutput bugs.
[jagi.git] / src / jagi / scgi / EventServer.java
1 package jagi.scgi;
2
3 import jagi.*;
4 import jagi.event.*;
5 import java.util.*;
6 import java.util.function.*;
7 import java.util.concurrent.*;
8 import java.util.logging.*;
9 import java.io.*;
10 import java.nio.*;
11 import java.nio.channels.*;
12
13 public class EventServer implements Runnable {
14     private static final double timeout = 5;
15     private static final Logger log = Logger.getLogger("jagi.server");
16     private final ServerSocketChannel sk;
17     private final Function handler;
18     private final Driver ev = Driver.get();
19     private final ExecutorService handlers = new ThreadPoolExecutor(0, Runtime.getRuntime().availableProcessors() * 2,
20                                                                     5, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(64),
21                                                                     tgt -> new Thread(tgt, "Request handler thread"));
22
23     public EventServer(ServerSocketChannel sk, Function handler) {
24         this.sk = sk;
25         this.handler = handler;
26     }
27
28     public static class Request {
29         public final Map<Object, Object> env;
30         public final SocketChannel sk;
31
32         public Request(Map<Object, Object> env, SocketChannel sk) {
33             this.env = env;
34             this.sk = sk;
35         }
36
37         public void close() {
38             ArrayList<Object> cleanup = new ArrayList<>((Collection<?>)env.get("jagi.cleanup"));
39             cleanup.add(sk);
40             RuntimeException ce = null;
41             for(Object obj : cleanup) {
42                 if(obj instanceof AutoCloseable) {
43                     try {
44                         ((AutoCloseable)obj).close();
45                     } catch(Exception e) {
46                         if(ce == null)
47                             ce = new RuntimeException("error(s) occurred during cleanup");
48                         ce.addSuppressed(e);
49                     }
50                 }
51             }
52             if(ce != null)
53                 throw(ce);
54         }
55     }
56
57     protected void error(Request req, Throwable error) {
58         log.log(Level.WARNING, "uncaught exception while handling request", error);
59     }
60
61     public static abstract class ChainWatcher implements Watcher {
62         private Runnable then;
63         public ChainWatcher then(Runnable then) {this.then = then; return(this);}
64
65         public void close() {
66             if(then != null)
67                 then.run();
68         }
69     }
70
71     public static class BufferedOutput extends ChainWatcher {
72         public final SocketChannel sk;
73         public final ByteBuffer buf;
74         private double lastwrite;
75
76         public BufferedOutput(SocketChannel sk, ByteBuffer buf) {
77             this.sk = sk;
78             this.buf = buf;
79         }
80
81         public void added(Driver d) {lastwrite = d.time();}
82         public SelectableChannel channel() {return(sk);}
83         public int events() {return((buf.remaining() > 0) ? SelectionKey.OP_WRITE : -1);}
84         public double timeout() {return(lastwrite + timeout);}
85
86         public void handle(int events) throws IOException {
87             double now = Driver.current().time();
88             if((events & SelectionKey.OP_WRITE) != 0) {
89                 if(sk.write(buf) > 0)
90                     lastwrite = now;
91             }
92             if(now > lastwrite + timeout)
93                 buf.position(buf.limit());
94         }
95     }
96
97     public static class TransferOutput extends ChainWatcher {
98         public final SocketChannel sk;
99         public final ReadableByteChannel in;
100         private final ByteBuffer buf;
101         private boolean eof = false;
102         private double lastwrite;
103
104         public TransferOutput(SocketChannel sk, ReadableByteChannel in) {
105             this.sk = sk;
106             this.in = in;
107             buf = ByteBuffer.allocate(65536);
108             buf.flip();
109         }
110
111         public void added(Driver d) {lastwrite = d.time();}
112         public SelectableChannel channel() {return(sk);}
113         public int events() {return((eof && (buf.remaining() == 0)) ? -1 : SelectionKey.OP_WRITE);}
114         public double timeout() {return(lastwrite + timeout);}
115
116         public void handle(int events) throws IOException {
117             if(!eof && (buf.remaining() == 0)) {
118                 buf.clear();
119                 while(buf.remaining() > 0) {
120                     if(in.read(buf) < 0) {
121                         eof = true;
122                         break;
123                     }
124                 }
125                 buf.flip();
126             }
127             double now = Driver.current().time();
128             if((events & SelectionKey.OP_WRITE) != 0) {
129                 if(sk.write(buf) > 0)
130                     lastwrite = now;
131             }
132             if(now > lastwrite + timeout) {
133                 eof = true;
134                 buf.position(buf.limit());
135             }
136         }
137
138         public void close() {
139             try {
140                 in.close();
141             } catch(IOException e) {
142                 log.log(Level.WARNING, "failed to close transfer channel: " + in, e);
143             } finally {
144                 super.close();
145             }
146         }
147     }
148
149     public static class TransferInput extends ChainWatcher {
150         public final SocketChannel sk;
151         public final WritableByteChannel out;
152         private final ByteBuffer buf;
153         private final long max;
154         private boolean eof = false;
155         private double lastread;
156         private long cur = 0;
157
158         public TransferInput(SocketChannel sk, WritableByteChannel out, long max) {
159             this.sk = sk;
160             this.out = out;
161             this.max = max;
162             buf = ByteBuffer.allocate(65536);
163             buf.flip();
164         }
165
166         public void added(Driver d) {lastread = d.time();}
167         public SelectableChannel channel() {return(sk);}
168         public int events() {return(eof ? -1 : SelectionKey.OP_READ);}
169         public double timeout() {return(lastread + timeout);}
170
171         public void handle(int events) throws IOException {
172             double now = Driver.current().time();
173             if((events & SelectionKey.OP_READ) != 0) {
174                 buf.clear();
175                 if(buf.remaining() > max - cur)
176                     buf.limit(buf.position() + (int)Math.min(max - cur, Integer.MAX_VALUE));
177                 int rv = sk.read(buf);
178                 if(rv < 0) {
179                     eof = true;
180                 } else if(rv > 0) {
181                     lastread = now;
182                     if((cur += rv) >= max)
183                         eof = true;
184                 }
185                 buf.flip();
186                 while(buf.remaining() > 0)
187                     out.write(buf);
188             }
189             if(now > lastread + timeout) {
190                 eof = true;
191                 buf.position(buf.limit());
192             }
193         }
194
195         public void close() {
196             try {
197                 out.close();
198             } catch(IOException e) {
199                 log.log(Level.WARNING, "failed to close transfer channel: " + out, e);
200             } finally {
201                 super.close();
202             }
203         }
204     }
205
206     protected void respond(Request req, String status, Map resp) {
207         Object output = resp.get("jagi.output");
208         ByteArrayOutputStream buf = new ByteArrayOutputStream();
209         try {
210             Writer head = new OutputStreamWriter(buf, Utils.UTF8);
211             head.write("Status: ");
212             head.write(status);
213             head.write("\n");
214             for(Iterator it = resp.entrySet().iterator(); it.hasNext();) {
215                 Map.Entry ent = (Map.Entry)it.next();
216                 Object val = ent.getValue();
217                 if((ent.getKey() instanceof String) && (val != null)) {
218                     String key = (String)ent.getKey();
219                     if(key.startsWith("http.")) {
220                         String nm = key.substring(5);
221                         if(nm.equalsIgnoreCase("status"))
222                             continue;
223                         if(val instanceof Collection) {
224                             for(Object part : (Collection)val) {
225                                 head.write(nm);
226                                 head.write(": ");
227                                 head.write(part.toString());
228                                 head.write("\n");
229                             }
230                         } else {
231                             head.write(nm);
232                             head.write(": ");
233                             head.write(val.toString());
234                             head.write("\n");
235                         }
236                     }
237                 }
238             }
239             head.write("\n");
240             head.flush();
241         } catch(IOException e) {
242             throw(new RuntimeException("cannot happen"));
243         }
244         ChainWatcher out;
245         if(output == null) {
246             out = new BufferedOutput(req.sk, ByteBuffer.allocate(0));
247         } else if(output instanceof byte[]) {
248             out = new BufferedOutput(req.sk, ByteBuffer.wrap((byte[])output));
249         } else if(output instanceof ByteBuffer) {
250             out = new BufferedOutput(req.sk, (ByteBuffer)output);
251         } else if(output instanceof String) {
252             out = new BufferedOutput(req.sk, ByteBuffer.wrap(((String)output).getBytes(Utils.UTF8)));
253         } else if(output instanceof CharSequence) {
254             out = new BufferedOutput(req.sk, Utils.UTF8.encode(CharBuffer.wrap((CharSequence)output)));
255         } else if(output instanceof InputStream) {
256             out = new TransferOutput(req.sk, Channels.newChannel((InputStream)output));
257         } else if(output instanceof ReadableByteChannel) {
258             out = new TransferOutput(req.sk, (ReadableByteChannel)output);
259         } else {
260             throw(new IllegalArgumentException("response-body: " + output));
261         }
262         out.then(() -> submit(req::close));
263         ev.add(new BufferedOutput(req.sk, ByteBuffer.wrap(buf.toByteArray())).then(() -> ev.add(out)));
264     }
265
266     @SuppressWarnings("unchecked")
267     protected void handle(Request req, Function handler) {
268         boolean handoff = false;
269         try {
270             Throwable error = null;
271             try {
272                 Map resp = (Map)handler.apply(req.env);
273                 String st;
274                 if((st = (String)resp.get("jagi.status")) != null) {
275                     Function next = (Function)resp.get("jagi.next");
276                     switch(st) {
277                     case "feed-input":
278                         Object sink = resp.get("jagi.input-sink");
279                         Object clen = req.env.get("HTTP_CONTENT_LENGTH");
280                         long max = 0;
281                         if(clen instanceof String) {
282                             try {
283                                 max = Long.parseLong((String)clen);
284                             } catch(NumberFormatException e) {
285                             }
286                         }
287                         if(sink instanceof WritableByteChannel) {
288                             ev.add(new TransferInput(req.sk, (WritableByteChannel)sink, max).then(() -> submit(() -> handle(req, next))));
289                         } else if(sink instanceof OutputStream) {
290                             ev.add(new TransferInput(req.sk, Channels.newChannel((OutputStream)sink), max).then(() -> submit(() -> handle(req, next))));
291                         } else {
292                             throw(new IllegalArgumentException("input-sink: " + sink));
293                         }
294                         handoff = true;
295                         break;
296                     case "chain":
297                         submit(() -> handle(req, next));
298                         handoff = true;
299                         break;
300                     default:
301                         throw(new IllegalArgumentException("jagi.status: " + st));
302                     }
303                 } else if((st = (String)resp.get("http.status")) != null) {
304                     respond(req, st, resp);
305                     handoff = true;
306                 } else {
307                     throw(new IllegalArgumentException("neither http.status nor jagi.status set"));
308                 }
309             } catch(Throwable t) {
310                 error = t;
311                 throw(t);
312             } finally {
313                 if(!handoff) {
314                     try {
315                         req.close();
316                     } catch(Throwable ce) {
317                         if(error == null) {
318                             throw(ce);
319                         } else {
320                             error.addSuppressed(ce);
321                         }
322                     }
323                 }
324             }
325         } catch(Throwable t) {
326             error(req, t);
327         }
328     }
329
330     protected void submit(Runnable task) {
331         handlers.submit(task);
332     }
333
334     class Client implements Watcher {
335         final SocketChannel sk;
336         double lastread;
337         boolean eof = false, handoff = false;
338         int headlen = 0;
339         ByteBuffer head = null;
340         Map<Object, Object> env = null;
341         Request req = null;
342
343         Client(SocketChannel sk) {
344             this.sk = sk;
345         }
346
347         public void added(Driver d) {lastread = d.time();}
348         public SelectableChannel channel() {return(sk);}
349         public double timeout() {return(lastread + timeout);}
350         public int events() {
351             if(eof)
352                 return(-1);
353             if(env == null)
354                 return(SelectionKey.OP_READ);
355             return(-1);
356         }
357
358         boolean readhead() throws IOException {
359             if(head == null) {
360                 ByteBuffer buf = ByteBuffer.allocate(1);
361                 while(true) {
362                     buf.rewind();
363                     int rv = sk.read(buf);
364                     if(rv < 0) {
365                         eof = true;
366                         return(false);
367                     } else if(rv == 0) {
368                         return(false);
369                     } else {
370                         lastread = Driver.current().time();
371                         int c = buf.get(0);
372                         if((c >= '0') && (c <= '9')) {
373                             headlen = (headlen * 10) + (c - '0');
374                         } else if(c == ':') {
375                             head = ByteBuffer.allocate(headlen + 1);
376                             break;
377                         } else {
378                             eof = true;
379                             return(false);
380                         }
381                     }
382                 }
383             }
384             while(true) {
385                 if(head.remaining() == 0) {
386                     if(head.get(head.limit() - 1) != ',') {
387                         /* Unterminated netstring */
388                         eof = true;
389                         return(false);
390                     }
391                     head.limit(head.limit() - 1);
392                     env = Jagi.mkenv(Scgi.splithead(head), sk);
393                     return(true);
394                 }
395                 int rv = sk.read(head);
396                 if(rv < 0) {
397                     eof = true;
398                     return(false);
399                 } else if(rv == 0) {
400                     return(false);
401                 }
402             }
403         }
404
405         public void handle(int events) throws IOException {
406             if((events & SelectionKey.OP_READ) != 0) {
407                 if((env == null) && !readhead())
408                     return;
409                 req = new Request(env, sk);
410                 handoff = true;
411             }
412             if(Driver.current().time() > (lastread + timeout))
413                 eof = true;
414         }
415
416         public void close() {
417             if(req != null)
418                 submit(() -> EventServer.this.handle(req, handler));
419             if(!handoff) {
420                 try {
421                     sk.close();
422                 } catch(IOException e) {
423                 }
424             }
425         }
426     }
427
428     class Accepter implements Watcher {
429         boolean closed = false;
430
431         public SelectableChannel channel() {return(sk);}
432         public int events() {return(SelectionKey.OP_ACCEPT);}
433
434         public void handle(int events) throws IOException {
435             if((events & SelectionKey.OP_ACCEPT) != 0)
436                 Driver.current().add(new Client(sk.accept()));
437         }
438
439         public void close() {
440             synchronized(this) {
441                 closed = true;
442                 notifyAll();
443             }
444         }
445     }
446
447     public void run() {
448         Accepter main = new Accepter();
449         ev.add(main);
450         try {
451             synchronized(main) {
452                 while(!main.closed) {
453                     main.wait();
454                 }
455             }
456         } catch(InterruptedException e) {
457             ev.remove(main);
458         } finally {
459             try {
460                 sk.close();
461             } catch(IOException e) {
462                 throw(new RuntimeException(e));
463             }
464         }
465     }
466 }