client: fix getting the response
[dyn-nsupdate.git] / nsupd-wrapper / dyn-nsupdate.cpp
1 /* Copyright (c) 2014, Ralf Jung <post@ralfj.de>
2 * All rights reserved.
3
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions are met:
6
7 * 1. Redistributions of source code must retain the above copyright notice, this
8 *    list of conditions and the following disclaimer. 
9 * 2. Redistributions in binary form must reproduce the above copyright notice,
10 *    this list of conditions and the following disclaimer in the documentation
11 *    and/or other materials provided with the distribution.
12
13 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
14 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
15 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
16 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
17 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
18 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
19 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
20 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
22 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 */
24
25 #include <iostream>
26 #include <fstream>
27 #include <sys/wait.h>
28
29 #include <boost/regex.hpp>
30 #include <boost/program_options.hpp>
31 #include <boost/property_tree/ptree.hpp>
32 #include <boost/property_tree/ini_parser.hpp>
33 #include <boost/iostreams/device/file_descriptor.hpp>
34 #include <boost/iostreams/stream.hpp>
35
36 namespace pt = boost::property_tree;
37 namespace po = boost::program_options;
38 using std::string;
39 using boost::regex;
40 using boost::optional;
41
42 static void write(int fd, const char *str)
43 {
44     size_t len = strlen(str);
45     ssize_t written = write(fd, str, len);
46     if (written < 0 || (size_t)written != len) {
47         std::cerr << "Error writing pipe." << std::endl;
48         exit(1);
49     }
50 }
51
52 int main(int argc, const char ** argv)
53 {
54     try {
55         // These regular expressions are not supposed to be fully precise: nsupdate will check the addresses, too.
56         // However, they have to make sure that there can be no injection attacks.
57 #define GROUP "[0-9]{1,3}"
58         static const regex regex_ipv4(GROUP "(\\." GROUP "){3}|");
59 #undef GROUP
60 #define GROUP "[a-fA-F0-9]{1,4}"
61         static const regex regex_ipv6("(" GROUP "(::?" GROUP "){0,6})?::?" GROUP "|");
62 #undef GROUP
63         
64         static const regex regex_password("[a-zA-Z0-9.:;,_-]+");
65         static const regex regex_domain("[a-zA-Z0-9.]+");
66         
67         // Declare the supported options.
68         po::options_description desc("Allowed options");
69         desc.add_options()
70             ("help", "produce help message")
71             ("domain", po::value<string>()->required(), "the domain to update")
72             ("password", po::value<string>()->required(), "the password for the domain")
73             ("ipv4", po::value<string>(), "the new IPv4 address (empty to delete the A record)")
74             ("ipv6", po::value<string>(), "the new IPv6 address (empty to delete the AAAA record)")
75         ;
76         
77         // parse arguments
78         po::variables_map vm;
79         po::store(po::parse_command_line(argc, argv, desc), vm);
80         po::notify(vm);    
81         if (vm.count("help")) {
82             std::cout << "dyn-nsupdate -- a safe setuid wrapper for nsupdate" << std::endl << std::endl;
83             std::cout << desc << "\n";
84             return 1;
85         }
86         string domain = vm["domain"].as<string>();
87         string password = vm["password"].as<string>();
88         bool haveIPv4 = vm.count("ipv4");
89         string ipv4 = haveIPv4 ? vm["ipv4"].as<string>() : "";
90         bool haveIPv6 = vm.count("ipv6");
91         string ipv6 = haveIPv6 ? vm["ipv6"].as<string>() : "";
92         
93         /* Validate input */
94         if (!regex_match(ipv4, regex_ipv4)) {
95             throw std::runtime_error("Invalid IPv4 address" + ipv4);
96         }
97         if (!regex_match(ipv6, regex_ipv6)) {
98             throw std::runtime_error("Invalid IPv6 address: " + ipv6);
99         }
100         if (!regex_match(domain, regex_domain)) {
101             throw std::runtime_error("Invalid Domain: " + domain);
102         }
103         if (!regex_match(password, regex_password)) {
104            throw std::runtime_error("Invalid Password: " + password);
105         }
106         
107         /* read configuration */
108         pt::ptree config;
109         pt::ini_parser::read_ini(CONFIG_FILE, config);
110         std::string nsupdate = config.get<std::string>("nsupdate");
111         unsigned server_port = config.get<unsigned>("port", 53);
112         std::string keyfile = config.get<std::string>("keyfile", "");
113         std::string key = config.get<std::string>("key", "");
114         
115         /* check for some invalid configurations */
116         if (keyfile.size() > 0 && key.size() > 0) {
117             std::cerr << "You can only have either a keyfile or a key set. Please fix your configuration." << std::endl;
118             exit(1);
119         }
120         
121         /* Given the domain, check whether the password matches */
122         optional<std::string> correct_password = config.get_optional<std::string>(pt::ptree::path_type(domain+"/password", '/'));
123         if (!correct_password || *correct_password != password) {
124             std::cerr << "Password incorrect." << std::endl;
125             exit(1);
126         }
127         
128         /* preapre the pipe */
129         int pipe_ends[2];
130         if (pipe(pipe_ends) < 0) {
131             std::cerr << "Error opening pipe." << std::endl;
132             exit(1);
133         }
134
135         /* Launch nsupdate */
136         pid_t child_pid = fork();
137         if (child_pid < 0) {
138             std::cerr << "Error while forking." << std::endl;
139             exit(1);
140         }
141         if (child_pid == 0) {
142             /* We're in the child */
143             /* Close write end, use read end as stdin */
144             close(pipe_ends[1]);
145             if (dup2(pipe_ends[0], fileno(stdin)) < 0) {
146                 std::cerr << "There was an error redirecting stdin." << std::endl;
147                 exit(1);
148             }
149             /* exec nsupdate */
150             if (keyfile.size() > 0) {
151                 execl(nsupdate.c_str(), nsupdate.c_str(), "-k", keyfile.c_str(), "-p", std::to_string(server_port).c_str(), "-l", (char *)NULL);
152             }
153             else if (key.size() > 0) {
154                 execl(nsupdate.c_str(), nsupdate.c_str(), "-y", key.c_str(), "-p", std::to_string(server_port).c_str(), "-l", (char *)NULL);
155             }
156             else {
157                 execl(nsupdate.c_str(), nsupdate.c_str(), "-p", std::to_string(server_port).c_str(), "-l", (char *)NULL);
158             }
159             /* There was an error */
160             std::cerr << "There was an error executing nsupdate." << std::endl;
161             exit(1);
162         }
163         
164         /* Send it the command */
165         if (haveIPv4) {
166             write(pipe_ends[1], "update delete ");
167             write(pipe_ends[1], domain.c_str());
168             write(pipe_ends[1], ". A\n");
169             
170             if (!ipv4.empty()) {
171                 write(pipe_ends[1], "update add ");
172                 write(pipe_ends[1], domain.c_str());
173                 write(pipe_ends[1], ". 60 A ");
174                 write(pipe_ends[1], ipv4.c_str());
175                 write(pipe_ends[1], "\n");
176             }
177         }
178         
179         if (haveIPv6) {
180             write(pipe_ends[1], "update delete ");
181             write(pipe_ends[1], domain.c_str());
182             write(pipe_ends[1], ". AAAA\n");
183             
184             if (!ipv6.empty()) {
185                 write(pipe_ends[1], "update add ");
186                 write(pipe_ends[1], domain.c_str());
187                 write(pipe_ends[1], ". 60 AAAA ");
188                 write(pipe_ends[1], ipv6.c_str());
189                 write(pipe_ends[1], "\n");
190             }
191         }
192         
193         write(pipe_ends[1], "send\n");
194         
195         /* Close both ends */
196         close(pipe_ends[0]);
197         close(pipe_ends[1]);
198         
199         /* Wait for child to be gone */
200         int child_status;
201         waitpid(child_pid, &child_status, 0);
202         if (child_status != 0) {
203             std::cerr << "There was an error in the child." << std::endl;
204             exit(1);
205         }
206     }
207     catch(std::exception &e) {
208         std::cout << e.what() << "\n";
209         return 1;
210     } 
211     
212     return 0;
213 }