Added event-driven server.
authorFredrik Tolf <fredrik@dolda2000.com>
Thu, 17 Feb 2022 12:51:50 +0000 (13:51 +0100)
committerFredrik Tolf <fredrik@dolda2000.com>
Thu, 17 Feb 2022 12:51:50 +0000 (13:51 +0100)
src/jagi/scgi/Bootstrap.java
src/jagi/scgi/EventServer.java [new file with mode: 0644]
src/jagi/scgi/Jagi.java
src/jagi/scgi/Scgi.java

index accac96..c2d4fbd 100644 (file)
@@ -159,7 +159,7 @@ public class Bootstrap {
        } else {
            sk = getstdin();
        }
-       Runnable server = new SimpleServer(sk, handler);
+       Runnable server = new EventServer(sk, handler);
        try {
            server.run();
        } catch(Throwable e) {
diff --git a/src/jagi/scgi/EventServer.java b/src/jagi/scgi/EventServer.java
new file mode 100644 (file)
index 0000000..3df6c03
--- /dev/null
@@ -0,0 +1,447 @@
+package jagi.scgi;
+
+import jagi.*;
+import jagi.event.*;
+import java.util.*;
+import java.util.function.*;
+import java.util.concurrent.*;
+import java.util.logging.*;
+import java.io.*;
+import java.nio.*;
+import java.nio.channels.*;
+
+public class EventServer implements Runnable {
+    private static final double timeout = 5;
+    private static final Logger log = Logger.getLogger("jagi.server");
+    private final ServerSocketChannel sk;
+    private final Function handler;
+    private final Driver ev = Driver.get();
+    private final ExecutorService handlers = new ThreadPoolExecutor(0, Runtime.getRuntime().availableProcessors() * 2,
+                                                                   5, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(64),
+                                                                   tgt -> new Thread(tgt, "Request handler thread"));
+
+    public EventServer(ServerSocketChannel sk, Function handler) {
+       try {
+           sk.configureBlocking(false);
+       } catch(IOException e) {
+           throw(new RuntimeException(e));
+       }
+       this.sk = sk;
+       this.handler = handler;
+    }
+
+    public static class Request {
+       public final Map<Object, Object> env;
+       public final SocketChannel sk;
+
+       public Request(Map<Object, Object> env, SocketChannel sk) {
+           this.env = env;
+           this.sk = sk;
+       }
+
+       public void close() {
+           ArrayList<Object> cleanup = new ArrayList<>((Collection<?>)env.get("jagi.cleanup"));
+           cleanup.add(sk);
+           RuntimeException ce = null;
+           for(Object obj : cleanup) {
+               if(obj instanceof AutoCloseable) {
+                   try {
+                       ((AutoCloseable)obj).close();
+                   } catch(Exception e) {
+                       if(ce == null)
+                           ce = new RuntimeException("error(s) occurred during cleanup");
+                       ce.addSuppressed(e);
+                   }
+               }
+           }
+           if(ce != null)
+               throw(ce);
+       }
+    }
+
+    protected void error(Request req, Throwable error) {
+       log.log(Level.WARNING, "uncaught exception while handling request", error);
+    }
+
+    public static abstract class ChainWatcher implements Watcher {
+       private Runnable then;
+       public ChainWatcher then(Runnable then) {this.then = then; return(this);}
+
+       public void close() {
+           if(then != null)
+               then.run();
+       }
+    }
+
+    public static class BufferedOutput extends ChainWatcher {
+       public final SocketChannel sk;
+       public final ByteBuffer buf;
+       private double lastwrite;
+
+       public BufferedOutput(SocketChannel sk, ByteBuffer buf) {
+           this.sk = sk;
+           this.buf = buf;
+       }
+
+       public void added(Driver d) {lastwrite = d.time();}
+       public SelectableChannel channel() {return(sk);}
+       public int events() {return((buf.remaining() > 0) ? SelectionKey.OP_WRITE : -1);}
+       public double timeout() {return(lastwrite + timeout);}
+
+       public void handle(int events) throws IOException {
+           double now = Driver.current().time();
+           if((events & SelectionKey.OP_WRITE) != 0) {
+               if(sk.write(buf) > 0)
+                   lastwrite = now;
+           }
+           if(now > lastwrite + timeout)
+               buf.position(buf.limit());
+       }
+    }
+
+    public static class TransferOutput extends ChainWatcher {
+       public final SocketChannel sk;
+       public final ReadableByteChannel in;
+       private final ByteBuffer buf;
+       private boolean eof = false;
+       private double lastwrite;
+
+       public TransferOutput(SocketChannel sk, ReadableByteChannel in) {
+           this.sk = sk;
+           this.in = in;
+           buf = ByteBuffer.allocate(65536);
+           buf.flip();
+       }
+
+       public void added(Driver d) {lastwrite = d.time();}
+       public SelectableChannel channel() {return(sk);}
+       public int events() {return((eof && (buf.remaining() == 0)) ? -1 : SelectionKey.OP_WRITE);}
+       public double timeout() {return(lastwrite + timeout);}
+
+       public void handle(int events) throws IOException {
+           if(!eof && (buf.remaining() == 0)) {
+               buf.rewind();
+               while(buf.remaining() > 0) {
+                   if(in.read(buf) < 0)
+                       break;
+               }
+           }
+           double now = Driver.current().time();
+           if((events & SelectionKey.OP_WRITE) != 0) {
+               if(sk.write(buf) > 0)
+                   lastwrite = now;
+           }
+           if(now > lastwrite + timeout) {
+               eof = true;
+               buf.position(buf.limit());
+           }
+       }
+
+       public void close() {
+           try {
+               in.close();
+           } catch(IOException e) {
+               log.log(Level.WARNING, "failed to close transfer channel: " + in, e);
+           } finally {
+               super.close();
+           }
+       }
+    }
+
+    public static class TransferInput extends ChainWatcher {
+       public final SocketChannel sk;
+       public final WritableByteChannel out;
+       private final ByteBuffer buf;
+       private boolean eof = false;
+       private double lastread;
+
+       public TransferInput(SocketChannel sk, WritableByteChannel out) {
+           this.sk = sk;
+           this.out = out;
+           buf = ByteBuffer.allocate(65536);
+           buf.flip();
+       }
+
+       public void added(Driver d) {lastread = d.time();}
+       public SelectableChannel channel() {return(sk);}
+       public int events() {return(eof ? -1 : SelectionKey.OP_READ);}
+       public double timeout() {return(lastread + timeout);}
+
+       public void handle(int events) throws IOException {
+           double now = Driver.current().time();
+           if((events & SelectionKey.OP_READ) != 0) {
+               buf.rewind();
+               int rv = sk.read(buf);
+               if(rv < 0)
+                   eof = true;
+               else if(rv > 0)
+                   lastread = now;
+               buf.flip();
+               while(buf.remaining() > 0)
+                   out.write(buf);
+           }
+           if(now > lastread + timeout) {
+               eof = true;
+               buf.position(buf.limit());
+           }
+       }
+
+       public void close() {
+           try {
+               out.close();
+           } catch(IOException e) {
+               log.log(Level.WARNING, "failed to close transfer channel: " + out, e);
+           } finally {
+               super.close();
+           }
+       }
+    }
+
+    protected void respond(Request req, String status, Map resp) {
+       Object output = resp.get("jagi.output");
+       ByteArrayOutputStream buf = new ByteArrayOutputStream();
+       try {
+           Writer head = new OutputStreamWriter(buf, Utils.UTF8);
+           head.write("Status: ");
+           head.write(status);
+           head.write("\n");
+           for(Iterator it = resp.entrySet().iterator(); it.hasNext();) {
+               Map.Entry ent = (Map.Entry)it.next();
+               Object val = ent.getValue();
+               if((ent.getKey() instanceof String) && (val != null)) {
+                   String key = (String)ent.getKey();
+                   if(key.startsWith("http.")) {
+                       String nm = key.substring(5);
+                       if(nm.equalsIgnoreCase("status"))
+                           continue;
+                       if(val instanceof Collection) {
+                           for(Object part : (Collection)val) {
+                               head.write(nm);
+                               head.write(": ");
+                               head.write(part.toString());
+                               head.write("\n");
+                           }
+                       } else {
+                           head.write(nm);
+                           head.write(": ");
+                           head.write(val.toString());
+                           head.write("\n");
+                       }
+                   }
+               }
+           }
+           head.write("\n");
+           head.flush();
+       } catch(IOException e) {
+           throw(new RuntimeException("cannot happen"));
+       }
+       ChainWatcher out;
+       if(output == null) {
+           out = new BufferedOutput(req.sk, ByteBuffer.allocate(0));
+       } else if(output instanceof byte[]) {
+           out = new BufferedOutput(req.sk, ByteBuffer.wrap((byte[])output));
+       } else if(output instanceof ByteBuffer) {
+           out = new BufferedOutput(req.sk, (ByteBuffer)output);
+       } else if(output instanceof String) {
+           out = new BufferedOutput(req.sk, ByteBuffer.wrap(((String)output).getBytes(Utils.UTF8)));
+       } else if(output instanceof CharSequence) {
+           out = new BufferedOutput(req.sk, Utils.UTF8.encode(CharBuffer.wrap((CharSequence)output)));
+       } else if(output instanceof InputStream) {
+           out = new TransferOutput(req.sk, Channels.newChannel((InputStream)output));
+       } else if(output instanceof ReadableByteChannel) {
+           out = new TransferOutput(req.sk, (ReadableByteChannel)output);
+       } else {
+           throw(new IllegalArgumentException("response-body: " + output));
+       }
+       out.then(() -> submit(req::close));
+       ev.add(new BufferedOutput(req.sk, ByteBuffer.wrap(buf.toByteArray())).then(() -> ev.add(out)));
+    }
+
+    @SuppressWarnings("unchecked")
+    protected void handle(Request req, Function handler) {
+       boolean handoff = false;
+       try {
+           Throwable error = null;
+           try {
+               Map resp = (Map)handler.apply(req.env);
+               String st;
+               if((st = (String)resp.get("jagi.status")) != null) {
+                   Function next = (Function)resp.get("jagi.next");
+                   switch(st) {
+                   case "feed-input":
+                       Object sink = resp.get("jagi.input-sink");
+                       if(sink instanceof WritableByteChannel) {
+                           ev.add(new TransferInput(req.sk, (WritableByteChannel)sink).then(() -> submit(() -> handle(req, next))));
+                       } else if(sink instanceof OutputStream) {
+                           ev.add(new TransferInput(req.sk, Channels.newChannel((OutputStream)sink)).then(() -> submit(() -> handle(req, next))));
+                       } else {
+                           throw(new IllegalArgumentException("input-sink: " + sink));
+                       }
+                       handoff = true;
+                       break;
+                   default:
+                       throw(new IllegalArgumentException("jagi.status: " + st));
+                   }
+               } else if((st = (String)resp.get("http.status")) != null) {
+                   respond(req, st, resp);
+                   handoff = true;
+               }
+           } catch(Throwable t) {
+               error = t;
+               throw(t);
+           } finally {
+               if(!handoff) {
+                   try {
+                       req.close();
+                   } catch(Throwable ce) {
+                       if(error == null) {
+                           throw(ce);
+                       } else {
+                           error.addSuppressed(ce);
+                       }
+                   }
+               }
+           }
+       } catch(Throwable t) {
+           error(req, t);
+       }
+    }
+
+    protected void submit(Runnable task) {
+       handlers.submit(task);
+    }
+
+    class Client implements Watcher {
+       final SocketChannel sk;
+       double lastread;
+       boolean eof = false, handoff = false;
+       int headlen = 0;
+       ByteBuffer head = null;
+       Map<Object, Object> env = null;
+
+       Client(SocketChannel sk) {
+           this.sk = sk;
+       }
+
+       public void added(Driver d) {lastread = d.time();}
+       public SelectableChannel channel() {return(sk);}
+       public double timeout() {return(lastread + timeout);}
+       public int events() {
+           if(eof)
+               return(-1);
+           if(env == null)
+               return(SelectionKey.OP_READ);
+           return(-1);
+       }
+
+       boolean readhead() throws IOException {
+           if(head == null) {
+               ByteBuffer buf = ByteBuffer.allocate(1);
+               while(true) {
+                   buf.rewind();
+                   int rv = sk.read(buf);
+                   if(rv < 0) {
+                       eof = true;
+                       return(false);
+                   } else if(rv == 0) {
+                       return(false);
+                   } else {
+                       lastread = Driver.current().time();
+                       int c = buf.get(0);
+                       if((c >= '0') && (c <= '9')) {
+                           headlen = (headlen * 10) + (c - '0');
+                       } else if(c == ':') {
+                           head = ByteBuffer.allocate(headlen + 1);
+                           break;
+                       } else {
+                           eof = true;
+                           return(false);
+                       }
+                   }
+               }
+           }
+           while(true) {
+               if(head.remaining() == 0) {
+                   if(head.get(head.limit() - 1) != ',') {
+                       /* Unterminated netstring */
+                       eof = true;
+                       return(false);
+                   }
+                   head.limit(head.limit() - 1);
+                   env = Jagi.mkenv(Scgi.splithead(head), sk);
+                   return(true);
+               }
+               int rv = sk.read(head);
+               if(rv < 0) {
+                   eof = true;
+                   return(false);
+               } else if(rv == 0) {
+                   return(false);
+               }
+           }
+       }
+
+       public void handle(int events) throws IOException {
+           if((events & SelectionKey.OP_READ) != 0) {
+               if((env == null) && !readhead())
+                   return;
+               Request req = new Request(env, sk);
+               submit(() -> EventServer.this.handle(req, handler));
+               handoff = true;
+           }
+           if(Driver.current().time() > (lastread + timeout))
+               eof = true;
+       }
+
+       public void close() {
+           if(!handoff) {
+               try {
+                   sk.close();
+               } catch(IOException e) {
+               }
+           }
+       }
+    }
+
+    class Accepter implements Watcher {
+       boolean closed = false;
+
+       public SelectableChannel channel() {return(sk);}
+       public int events() {return(SelectionKey.OP_ACCEPT);}
+
+       public void handle(int events) throws IOException {
+           if((events & SelectionKey.OP_ACCEPT) != 0) {
+               SocketChannel cl = sk.accept();
+               cl.configureBlocking(false);
+               Driver.current().add(new Client(cl));
+           }
+       }
+
+       public void close() {
+           synchronized(this) {
+               closed = true;
+               notifyAll();
+           }
+       }
+    }
+
+    public void run() {
+       Accepter main = new Accepter();
+       ev.add(main);
+       try {
+           synchronized(main) {
+               while(!main.closed) {
+                   main.wait();
+               }
+           }
+       } catch(InterruptedException e) {
+           ev.remove(main);
+       } finally {
+           try {
+               sk.close();
+           } catch(IOException e) {
+               throw(new RuntimeException(e));
+           }
+       }
+    }
+}
index 758afdc..7337f97 100644 (file)
@@ -13,8 +13,7 @@ public class Jagi {
            into.put(coding.newDecoder().decode(h.getKey()).toString(), coding.decode(h.getValue()).toString());
     }
 
-    public static Map<Object, Object> mkenv(ReadableByteChannel sk) throws IOException {
-       Map<ByteBuffer, ByteBuffer> rawhead = Scgi.readhead(sk);
+    public static Map<Object, Object> mkenv(Map<ByteBuffer, ByteBuffer> rawhead, ReadableByteChannel input) throws IOException {
        Map<Object, Object> env;
        try {
            env = new HashMap<>();
@@ -33,7 +32,7 @@ public class Jagi {
            env.put("jagi.url_scheme", "https");
        else
            env.put("jagi.url_scheme", "http");
-       env.put("jagi.input", sk);
+       env.put("jagi.input", input);
        env.put("jagi.errors", System.err);
        env.put("jagi.multithread", true);
        env.put("jagi.multiprocess", false);
@@ -41,4 +40,8 @@ public class Jagi {
        env.put("jagi.cleanup", new HashSet<>());
        return(env);
     }
+
+    public static Map<Object, Object> mkenv(ReadableByteChannel sk) throws IOException {
+       return(mkenv(Scgi.readhead(sk), sk));
+    }
 }
index 5e219b0..f7f57d2 100644 (file)
@@ -26,9 +26,8 @@ public class Scgi {
        return(data);
     }
 
-    public static Map<ByteBuffer, ByteBuffer> readhead(ReadableByteChannel sk) throws IOException {
+    public static Map<ByteBuffer, ByteBuffer> splithead(ByteBuffer ns) {
        Map<ByteBuffer, ByteBuffer> ret = new HashMap<>();
-       ByteBuffer ns = readns(sk);
        ByteBuffer k = null;
        for(int i = 0, p = 0; i < ns.limit(); i++) {
            if(ns.get(i) == 0) {
@@ -45,4 +44,8 @@ public class Scgi {
        }
        return(ret);
     }
+
+    public static Map<ByteBuffer, ByteBuffer> readhead(ReadableByteChannel sk) throws IOException {
+       return(splithead(readns(sk)));
+    }
 }