#!/usr/bin/python3
import logging, logging.handlers
import os, sys, shlex, pwd
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# Configuration
# Command to run when schsh is started without arguments (a shell is requested).
# None refuses such sessions; set to "/bin/bash" or similar to allow shell access.
shell = None
# Path to the restricted rsync script. When set, allowRSync() runs this script
# instead of rsync itself to further restrict rsync access. Set to None if the
# script is not installed; /usr/bin/rsync is then run directly.
rrsync = "/usr/local/bin/schsh-rrsync"

def allowSCP(run, runstr):
	"""Accept only `scp -f <arg>` or `scp -t <arg>`.

	run    -- the requested command split into words (modified in place on success)
	runstr -- the unsplit command string (not used by this checker)

	Returns True and replaces run[0] with the absolute scp path when the
	command matches; returns False and leaves run untouched otherwise.
	"""
	# exactly three words: program, mode flag, argument
	if len(run) != 3:
		return False
	program, mode, target = run
	accepted = (
		program == "scp"
		and mode in ("-f", "-t")
		# the last word must not look like another option
		and not target.startswith('-')
	)
	if accepted:
		run[0] = "/usr/bin/scp"
	return accepted

def allowRSync(run, runstr):
	"""Accept `rsync --server ...` invocations.

	run    -- the requested command split into words (modified in place on success)
	runstr -- the unsplit command string

	Returns False unless run has at least three words and starts with
	"rsync", "--server". On success, run is rewritten: to the restricted
	rsync script if one is configured, otherwise only run[0] is replaced
	by the absolute rsync path.
	"""
	looks_like_server = len(run) >= 3 and run[0] == "rsync" and run[1] == "--server"
	if not looks_like_server:
		return False
	if rrsync is not None:
		# hand the whole original command to the restricted script;
		# "/" allows access to the entire chroot
		run[:] = [rrsync, "/", runstr]
	else:
		# rrsync is not available, let's hope this is enough protection
		run[0] = "/usr/bin/rsync"
	return True

def allowSFTP(run, runstr):
	"""Accept only the exact sftp-server command string.

	run    -- the requested command split into words (left unchanged)
	runstr -- the unsplit command string; must match the server path exactly

	Returns True on an exact match, False otherwise.
	"""
	sftp_server = "/usr/lib/openssh/sftp-server"
	if runstr == sftp_server:
		return True
	return False

# Checkers consulted in order by commandAllowed(); the command is accepted as
# soon as one of them returns a true value. Each checker is called as
# checker(run, runstr) and may rewrite the run list in place.
allowCommands = [allowSCP, allowRSync, allowSFTP]

# END of Configuration
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# DO NOT TOUCH ANYTHING BELOW THIS LINE


# Module-wide logger: records at INFO and above are sent to the local syslog
# socket (/dev/log) under the AUTH facility.
logger = logging.getLogger("schsh")
logger.setLevel(logging.INFO)
logger.addHandler(
	logging.handlers.SysLogHandler(
		facility = logging.handlers.SysLogHandler.LOG_AUTH,
		address = '/dev/log',
	)
)

def get_username():
    """Return the passwd login name for the real user id of this process."""
    uid = os.getuid()
    passwd_entry = pwd.getpwuid(uid)
    return passwd_entry.pw_name

def log(msg, lvl = logging.INFO):
	"""Send msg to the module logger at level lvl.

	The record text has the form "schsh[<pid>]: <<user>> <msg>", where the
	pid and user name are those of the current process.
	"""
	fields = ("schsh", os.getpid(), get_username(), msg)
	record = "%s[%d]: <%s> %s" % fields
	logger.log(lvl, record)

def logquit(msg):
	"""Log msg at ERROR level, then terminate the process with exit status 1."""
	log(msg, lvl = logging.ERROR)
	raise SystemExit(1)

def commandAllowed(run, runstr):
	"""Return True if any checker in allowCommands accepts the command.

	run    -- the requested command split into words; an accepting checker
	          may have rewritten it in place
	runstr -- the unsplit command string

	Checkers are tried in list order and evaluation stops at the first one
	that returns a true value.
	"""
	return any(checker(run, runstr) for checker in allowCommands)

# parse arguments
# Two invocation forms are accepted: no arguments (a shell is requested) and
# "-c <command>" (a single command is requested). Anything else is logged and
# refused. logquit() exits the process, so every refusal path stops here.
run = []
if len(sys.argv) == 1:
	if shell is None:
		print("No shell for you!")
		logquit("Shell access not allowed")
	else:
		run = [shell]
elif len(sys.argv) == 3 and sys.argv[1] == "-c":
	# check if the command is allowed, and add path
	try:
		run = shlex.split(sys.argv[2])
	except ValueError as err:
		# shlex raises ValueError on unbalanced quotes or a dangling escape.
		# Refuse such input like any other disallowed command instead of
		# dying with an unlogged traceback.
		print("You are not allowed to run this command.")
		logquit("Attempt to run unparsable command "+repr(sys.argv[2])+": "+str(err))
	if commandAllowed(run, sys.argv[2]): # this may change run, but that's okay
		log("Running "+str(run))
	else:
		print("You are not allowed to run this command.")
		# repr() escapes newlines and other control characters, so the
		# client-supplied string cannot inject extra lines into the log
		logquit("Attempt to run invalid command "+repr(sys.argv[2]))
else:
	logquit("Invalid arguments for schsh: "+str(sys.argv))

# An empty command list must never be handed to schroot. This is an explicit
# check rather than an assert because asserts are removed when Python runs
# with -O.
if not run:
	logquit("Refusing to run an empty command")
# Replace this process with schroot, running the command in the per-user
# chroot "schsh-<username>" with /data as the working directory.
os.execl("/usr/bin/schroot", "/usr/bin/schroot", "-c", "schsh-"+get_username(), "-d", "/data", "--", *run)