#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <string>

#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

// "Null" link that terminates a node list. Lists are linked by array index, not by pointer.
constexpr int NIL = -1;
// Number of entries in SharedData::results. parallel_sort numbers its recursion tree like a
// binary heap (slot s has children 2s+1 and 2s+2), so this bounds how many slots can be used.
constexpr int MAX_RESULT_SLOTS = 4096;

// One element of an index-linked singly linked list.
struct Node {
    int32_t value;  // sort key
    int next;       // index of the following node, or NIL
};

// Layout of the shared-memory segment used by all sorting processes.
// `nodes` is declared with one element, but the segment is allocated with room for more
// (the classic "struct hack"); indexing past nodes[0] relies on that over-allocation.
struct SharedData {
    int results[MAX_RESULT_SLOTS];  // sorted-sublist head published per recursion slot
    Node nodes[1];
};

// Command-line settings (filled in by parse_args).
struct Options {
    size_t size = 100000;    // number of nodes to generate and sort
    int max_depth = 3;       // maximum fork depth
    size_t min_size = 4096;  // sublists of at most this many nodes are sorted without forking
    unsigned seed = 1337;    // RNG seed for the generated values
    bool print = false;      // print the sorted values to stdout
};

// Outcome of sorting one sublist.
struct SortResult {
    int head = NIL;          // head index of the sorted list
    uint64_t processes = 1;  // processes used to sort it, this one included
};

// Throws std::runtime_error("<what>: <strerror(errno)>").
// errno is captured first: building the message allocates, and library calls are allowed to
// modify errno even when they succeed, which could otherwise corrupt the reported error.
[[noreturn]] static void throw_errno(const std::string& what) {
    const int err = errno;
    throw std::runtime_error(what + ": " + std::strerror(err));
}

// Merges two ascending lists in place by relinking `next` fields and returns the merged head.
// Stable: on equal values the node from `left` is taken first. Either input may be NIL.
static int merge_lists(Node* nodes, int left, int right) {
    if (left == NIL) return right;
    if (right == NIL) return left;

    int head = NIL;
    int tail = NIL;
    auto append = [&](int idx) {
        if (head == NIL) {
            head = idx;
        } else {
            nodes[tail].next = idx;
        }
        tail = idx;
    };

    while (left != NIL && right != NIL) {
        if (nodes[left].value <= nodes[right].value) {
            const int next = nodes[left].next;
            append(left);
            left = next;
        } else {
            const int next = nodes[right].next;
            append(right);
            right = next;
        }
    }
    // Exactly one input still has nodes; splice its (already sorted) remainder on.
    // `tail` is valid because both inputs were non-empty, so at least one node was appended.
    nodes[tail].next = (left != NIL) ? left : right;
    return head;
}

// Cuts the list at `head` in place into two halves: `left` receives the first ceil(n/2)
// nodes and `right` the remaining floor(n/2). For n < 2, left = head and right = NIL.
static void split_list(Node* nodes, int head, int& left, int& right) {
    if (head == NIL || nodes[head].next == NIL) {
        left = head;
        right = NIL;
        return;
    }

    // `fast` advances two links per step while `slow` advances one, so `slow` ends on the
    // last node of the first half.
    int slow = head;
    int fast = nodes[head].next;
    while (fast != NIL) {
        fast = nodes[fast].next;
        if (fast != NIL) {
            slow = nodes[slow].next;
            fast = nodes[fast].next;
        }
    }

    left = head;
    right = nodes[slow].next;
    nodes[slow].next = NIL;
}

// Top-down merge sort of the list at `head` in the current process; returns the new head.
// Stable (see merge_lists). Recursion depth is about log2(n).
static int local_sort(Node* nodes, int head) {
    if (head == NIL || nodes[head].next == NIL) return head;

    int left = NIL;
    int right = NIL;
    split_list(nodes, head, left, right);
    left = local_sort(nodes, left);
    right = local_sort(nodes, right);
    return merge_lists(nodes, left, right);
}

// Number of nodes reachable from `head` (0 for NIL).
static size_t list_length(const Node* nodes, int head) {
    size_t n = 0;
    for (int cur = head; cur != NIL; cur = nodes[cur].next) ++n;
    return n;
}
static SortResult parallel_sort(SharedData* data, int head, int depth, const Options& opt, int slot) { const size_t n = list_length(data->nodes, head); if (head == NIL || data->nodes[head].next == NIL || depth >= opt.max_depth || n <= opt.min_size) { return {local_sort(data->nodes, head), 1}; } int left = NIL; int right = NIL; split_list(data->nodes, head, left, right); const int left_slot = slot * 2 + 1; const int right_slot = slot * 2 + 2; if (right_slot >= MAX_RESULT_SLOTS) { int sorted_left = local_sort(data->nodes, left); int sorted_right = local_sort(data->nodes, right); return {merge_lists(data->nodes, sorted_left, sorted_right), 1}; } pid_t left_pid = fork(); if (left_pid < 0) throw_errno("fork left"); if (left_pid == 0) { try { SortResult r = parallel_sort(data, left, depth + 1, opt, left_slot); data->results[left_slot] = r.head; _exit(0); } catch (...) { _exit(2); } } pid_t right_pid = fork(); if (right_pid < 0) throw_errno("fork right"); if (right_pid == 0) { try { SortResult r = parallel_sort(data, right, depth + 1, opt, right_slot); data->results[right_slot] = r.head; _exit(0); } catch (...) 
{ _exit(2); } } int left_status = 0; int right_status = 0; while (waitpid(left_pid, &left_status, 0) < 0 && errno == EINTR) {} while (waitpid(right_pid, &right_status, 0) < 0 && errno == EINTR) {} if (!WIFEXITED(left_status) || WEXITSTATUS(left_status) != 0) { throw std::runtime_error("left child failed"); } if (!WIFEXITED(right_status) || WEXITSTATUS(right_status) != 0) { throw std::runtime_error("right child failed"); } return {merge_lists(data->nodes, data->results[left_slot], data->results[right_slot]), static_cast((1ULL << (depth + 2)) - 1ULL)}; } static Options parse_args(int argc, char** argv) { Options opt; for (int i = 1; i < argc; ++i) { std::string s = argv[i]; auto value = [&](const std::string& name) -> std::string { if (i + 1 >= argc) throw std::runtime_error("missing value for " + name); return argv[++i]; }; if (s == "--size" || s == "-n") { opt.size = std::stoull(value(s)); } else if (s == "--depth" || s == "-d") { opt.max_depth = std::stoi(value(s)); } else if (s == "--min-size" || s == "-m") { opt.min_size = std::stoull(value(s)); } else if (s == "--seed") { opt.seed = static_cast(std::stoul(value(s))); } else if (s == "--print") { opt.print = true; } else if (s == "--help" || s == "-h") { std::cout << "Usage: ./shm_sort [--size N] [--depth D] [--min-size M] " << "[--seed S] [--print]\n"; std::exit(0); } else { throw std::runtime_error("unknown argument: " + s); } } if (opt.max_depth < 0) throw std::runtime_error("depth must be non-negative"); if (opt.size == 0) throw std::runtime_error("size must be positive"); return opt; } static bool is_sorted_list(const Node* nodes, int head, size_t expected) { size_t seen = 0; int32_t prev = 0; bool first = true; while (head != NIL) { if (!first && prev > nodes[head].value) return false; first = false; prev = nodes[head].value; head = nodes[head].next; ++seen; } return seen == expected; } int main(int argc, char** argv) { int shmid = -1; SharedData* data = reinterpret_cast(-1); try { Options opt = 
parse_args(argc, argv); const size_t shm_size = sizeof(SharedData) + sizeof(Node) * (opt.size - 1); shmid = shmget(IPC_PRIVATE, shm_size, IPC_CREAT | 0600); if (shmid < 0) throw_errno("shmget"); data = static_cast(shmat(shmid, nullptr, 0)); if (data == reinterpret_cast(-1)) throw_errno("shmat"); std::fill(data->results, data->results + MAX_RESULT_SLOTS, NIL); std::mt19937 rng(opt.seed); std::uniform_int_distribution dist(-100000000, 100000000); for (size_t i = 0; i < opt.size; ++i) { data->nodes[i].value = dist(rng); data->nodes[i].next = (i + 1 == opt.size) ? NIL : static_cast(i + 1); } const auto t1 = std::chrono::steady_clock::now(); SortResult result = parallel_sort(data, 0, 0, opt, 0); const auto t2 = std::chrono::steady_clock::now(); const double elapsed = std::chrono::duration(t2 - t1).count(); const bool ok = is_sorted_list(data->nodes, result.head, opt.size); if (opt.print) { for (int cur = result.head; cur != NIL; cur = data->nodes[cur].next) { std::cout << data->nodes[cur].value << ' '; } std::cout << '\n'; } std::cerr << "STAT: program=shm size=" << opt.size << " depth=" << opt.max_depth << " min_size=" << opt.min_size << " valid=" << (ok ? 1 : 0) << " time=" << elapsed << " sec\n"; shmdt(data); shmctl(shmid, IPC_RMID, nullptr); return ok ? 0 : 3; } catch (const std::exception& e) { if (data != reinterpret_cast(-1)) shmdt(data); if (shmid >= 0) shmctl(shmid, IPC_RMID, nullptr); std::cerr << "ERROR: " << e.what() << "\n"; return 1; } }