diff --git a/serve.py b/serve.py index a94709a..56a0068 100644 --- a/serve.py +++ b/serve.py @@ -12,6 +12,19 @@ GRAB_SYMLINKS = True # Toggle symlink handling on/off SYMLINK_DIR = os.path.abspath("symlinks") # Directory to resolve symlinks to STATIC_SHARE_RESTRICTION = True # Toggle restriction on/off +def get_project_directory(project_id): + project_dir = os.path.join(BASE_DIR, "projects", project_id) + if GRAB_SYMLINKS and not os.path.exists(project_dir): + # If the project does not exist in the projects dir, check if it exists in the symlinks dir + if not os.path.exists(SYMLINK_DIR) or not os.path.isdir(SYMLINK_DIR): + return None + # Find the symlink directory that starts with the project_id + symlink_dirs = [d for d in os.listdir(SYMLINK_DIR) if os.path.isdir(os.path.join(SYMLINK_DIR, d))] + for symlink_dir in symlink_dirs: + if symlink_dir.startswith(project_id): + return os.path.join(SYMLINK_DIR, symlink_dir) + return project_dir + class CustomHandler(http.server.SimpleHTTPRequestHandler): def translate_path(self, path): path = unquote(path) @@ -19,7 +32,7 @@ class CustomHandler(http.server.SimpleHTTPRequestHandler): if path.startswith("/project/"): parts = path[len("/project/"):].split("/", 1) project_id = parts[0] - project_dir = os.path.join(BASE_DIR, "projects", project_id) + project_dir = get_project_directory(project_id) # Check .staticshare restriction if STATIC_SHARE_RESTRICTION: @@ -51,24 +64,7 @@ class CustomHandler(http.server.SimpleHTTPRequestHandler): parts = path[len("/project/"):].split("/", 1) project_id = parts[0] - project_dir = os.path.join(BASE_DIR, "projects", project_id) - - if GRAB_SYMLINKS and not os.path.exists(project_dir): - # if the project does not exist in the projects dir, check if it exists in the symlinks dir - # get all directories in the symlinks directory - if not os.path.exists(SYMLINK_DIR) or not os.path.isdir(SYMLINK_DIR): - self.send_response(404) - self.send_header("Content-type", "text/html") - self.end_headers() - with open(os.path.join(BASE_DIR, "404.html"), "rb") as f: - self.wfile.write(f.read()) - return - # find the symlink directory that starts with the project_id - symlink_dirs = [d for d in os.listdir(SYMLINK_DIR) if os.path.isdir(os.path.join(SYMLINK_DIR, d))] - for symlink_dir in symlink_dirs: - if symlink_dir.startswith(project_id): - project_dir = os.path.join(SYMLINK_DIR, symlink_dir) - break + project_dir = get_project_directory(project_id) if STATIC_SHARE_RESTRICTION: staticshare_path = os.path.join(project_dir, ".staticshare")