diff --git a/src/wstationd.c b/src/wstationd.c index c8879ca..ebcb23d 100644 --- a/src/wstationd.c +++ b/src/wstationd.c @@ -17,6 +17,7 @@ #include #include #include +#include #include "wstationd.h" @@ -27,16 +28,18 @@ // dictates the memory cost per connection. #define BUFFERSIZE 65539 -// Hard limit on number of concurrent connections -#define SLOTLIMIT 128 - // Timeout for each connection, in seconds #define TIMEOUT 30 static char *exec_name; +// Should this application daemonize? static int daemonize = 1; +// Hard limit on number of concurrent connections +static int slotlimit = 128; +// Used to handle signaled shutdowns +static int keep_polling = 1; struct connection { @@ -50,10 +53,16 @@ struct connection static int write_connection(struct connection *); +void shutdown_handler(int sig) +{ + keep_polling = 0; +} + static void print_help() { - fprintf(stdout, "usage: %s [-p port] [-b addr] directory\n", exec_name); + fprintf(stdout, "usage: %s [-d] [-l limit] [-p port] [-b addr] directory\n", exec_name); fprintf(stdout, " -d, --nodaemon Do not detach and daemonize\n"); + fprintf(stdout, " -l, --limit Set concurrent connection limit (default: 128)\n"); fprintf(stdout, " -p, --port Port to listen on (default: 10800)\n"); fprintf(stdout, " -b, --bind
Address to bind to (default: 0.0.0.0)\n"); fprintf(stdout, " -h, --help Print this message and exit\n"); @@ -64,6 +73,7 @@ static void print_help() int main(int argc, char* argv[]) { const struct option longopts[] = { + { "limit", required_argument, NULL, 'l' }, { "nodaemon", required_argument, NULL, 'd' }, { "port", required_argument, NULL, 'p' }, { "bind", required_argument, NULL, 'b' }, @@ -77,8 +87,8 @@ int main(int argc, char* argv[]) int port = 10800; // Default port uint32_t address = INADDR_ANY; // Default listen address - struct connection* conns[SLOTLIMIT]; - memset(conns, 0, sizeof(struct connection*) * SLOTLIMIT); + struct connection** conns = NULL; + struct pollfd* fds = NULL; if((exec_name = basename(argv[0])) == NULL){ fprintf(stderr, "%s: cannot get basename - %s\n", @@ -92,6 +102,15 @@ int main(int argc, char* argv[]) case 'd': daemonize = 0; break; + + case 'l': + slotlimit = (int)strtol(optarg, NULL, 10); + if(slotlimit < 1 || slotlimit > 1024){ + fprintf(stderr, "%s: invalid limit (must be between 1 and 1024)\n", exec_name); + print_help(); + goto shutdown_error; + } + break; case 'p': port = (int)strtol(optarg, NULL, 10); @@ -135,6 +154,9 @@ int main(int argc, char* argv[]) goto shutdown_error; } + // Handle signals + signal(SIGINT, shutdown_handler); + // Daemonize if requested if(daemonize){ pid_t pid = fork(); @@ -228,31 +250,41 @@ int main(int argc, char* argv[]) goto shutdown_error; } - struct pollfd fds[SLOTLIMIT+1]; - memset(&fds, 0, sizeof(struct pollfd) * (SLOTLIMIT + 1)); + fds = malloc(sizeof(struct pollfd) * (slotlimit + 1)); + conns = malloc(sizeof(struct connection*) * slotlimit); + if(conns == NULL || fds == NULL){ + if(daemonize){ syslog(LOG_ERR, "out of memory"); } + else { fprintf(stderr, "%s: out of memory\n", exec_name); } + } + + memset(conns, 0, sizeof(struct connection*) * slotlimit); + memset(fds, 0, sizeof(struct pollfd) * (slotlimit + 1)); // Add sockfd fds[0].fd = sockfd; fds[0].events = POLLIN; fds[0].revents = 0; - while(1){ - if(poll(fds, (SLOTLIMIT + 1), 1000) == -1){ - if(daemonize){ - syslog(LOG_ERR, "socket poll error - %s", - strerror(errno) - ); - } else { - fprintf(stderr, "%s: socket poll error - %s\n", - exec_name, strerror(errno) - ); + while(keep_polling){ + if(poll(fds, (slotlimit + 1), 1000) == -1){ + // Ignore interrupt generated errors + if(errno != EINTR){ + if(daemonize){ + syslog(LOG_ERR, "socket poll error - %s", + strerror(errno) + ); + } else { + fprintf(stderr, "%s: socket poll error - %s\n", + exec_name, strerror(errno) + ); + } + goto shutdown_error; } - goto shutdown_error; } // Handle existing connections int slotfirst = 0, slotfree = 0; - for(int i = 1; i <= SLOTLIMIT; i++){ + for(int i = 1; i <= slotlimit; i++){ // Handle event if(fds[i].revents & POLLIN){ ssize_t sz = recv( @@ -390,12 +422,24 @@ int main(int argc, char* argv[]) shutdown_clean: if(sockfd > 0) close(sockfd); - for(int i = 0; i < SLOTLIMIT; i++){ if(conns[i] != NULL) free(conns[i]); } + if(fds != NULL){ free(fds); fds = NULL; } + if(conns != NULL){ + for(int i = 0; i < slotlimit; i++){ + if(conns[i] != NULL) free(conns[i]); + } + free(conns); conns = NULL; + } return 0; shutdown_error: if(sockfd > 0) close(sockfd); - for(int i = 0; i < SLOTLIMIT; i++){ if(conns[i] != NULL) free(conns[i]); } + if(fds != NULL){ free(fds); fds = NULL; } + if(conns != NULL){ + for(int i = 0; i < slotlimit; i++){ + if(conns[i] != NULL) free(conns[i]); + } + free(conns); conns = NULL; + } return 1; }