Files
2026-05-13 08:54:08 +07:00

378 lines
11 KiB
C++

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <csignal>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include <sys/wait.h>
#include <unistd.h>
// Sentinel index meaning "no node": terminates a list / marks an empty list.
constexpr int NIL{-1};
// One element of a singly linked list whose nodes live in a std::vector and
// are linked by index rather than by pointer.
struct Node {
    int32_t value;  // the value being sorted
    int next;       // index of the following node, or NIL at the end
};
// Run-time configuration; defaults below, overridden by command-line flags.
struct Options {
    size_t size{100000};    // how many random values to generate and sort
    int max_depth{3};       // fork levels before a piece is sorted in-process
    size_t min_size{4096};  // pieces of this size or smaller are sorted in-process
    unsigned seed{1337};    // seed for the input random number generator
    bool print{false};      // whether to print the sorted values to stdout
};
// What parallel_sort() hands back: the sorted values.
struct SortResult {
    std::vector<int32_t> values;  // ascending order
};
/// Throws std::runtime_error with the message "<what>: <strerror(errno)>".
///
/// errno is copied into a local before anything else runs. Building the
/// message allocates, allocation may modify errno, and the evaluation order
/// of the operands of an overloaded operator+ is unspecified. Reading errno
/// inside the same expression could therefore report the wrong error.
///
/// @param what  short description of the failed operation (e.g. "read").
/// @throws std::runtime_error always.
[[noreturn]] static void throw_errno(const std::string& what) {
    const int saved_errno = errno;
    throw std::runtime_error(what + ": " + std::strerror(saved_errno));
}
/// Closes `fd` if it is non-negative; negative values (e.g. "not opened")
/// are ignored. This is best effort: errors from close() are ignored.
///
/// close() is intentionally called only once, even if it fails with EINTR.
/// On Linux the descriptor is released before EINTR can be reported, so a
/// retry could close an unrelated descriptor that reused the same number.
/// errno is restored afterwards so that closing descriptors on an error path
/// does not overwrite the errno value the caller is about to report.
static void close_checked(int fd) {
    if (fd < 0) return;
    const int saved_errno = errno;
    (void)close(fd);
    errno = saved_errno;
}
/// Writes exactly `bytes` bytes starting at `data` to `fd`.
/// Handles short writes by looping and retries writes interrupted by EINTR.
/// @throws std::runtime_error on any other write() error, or if write()
///         reports zero bytes written.
static void write_all(int fd, const void* data, size_t bytes) {
    const char* cursor = static_cast<const char*>(data);
    size_t remaining = bytes;
    while (remaining != 0) {
        const ssize_t written = write(fd, cursor, remaining);
        if (written > 0) {
            cursor += written;
            remaining -= static_cast<size_t>(written);
        } else if (written == 0) {
            throw std::runtime_error("write returned zero");
        } else if (errno != EINTR) {
            throw_errno("write");
        }
    }
}
/// Reads exactly `bytes` bytes from `fd` into `data`.
/// Handles short reads by looping and retries reads interrupted by EINTR.
/// @throws std::runtime_error if the stream ends early ("unexpected EOF")
///         or read() fails with anything other than EINTR.
static void read_all(int fd, void* data, size_t bytes) {
    char* cursor = static_cast<char*>(data);
    size_t remaining = bytes;
    while (remaining != 0) {
        const ssize_t got = read(fd, cursor, remaining);
        if (got > 0) {
            cursor += got;
            remaining -= static_cast<size_t>(got);
        } else if (got == 0) {
            throw std::runtime_error("unexpected EOF");
        } else if (errno != EINTR) {
            throw_errno("read");
        }
    }
}
/// Sends `values` over `fd` as a uint64_t element count followed by the
/// raw int32_t elements (host byte order). This is the framing recv_values()
/// expects.
static void send_values(int fd, const std::vector<int32_t>& values) {
    const uint64_t count = values.size();
    write_all(fd, &count, sizeof count);
    if (count == 0) return;
    write_all(fd, values.data(), values.size() * sizeof(int32_t));
}
/// Receives a vector written by send_values(): a uint64_t count followed by
/// that many raw int32_t elements.
/// @throws std::runtime_error if the count cannot be represented as a byte
///         size ("input too large"), or on read failure / early EOF.
static std::vector<int32_t> recv_values(int fd) {
    uint64_t count = 0;
    read_all(fd, &count, sizeof count);
    constexpr uint64_t kMaxCount = static_cast<uint64_t>(SIZE_MAX / sizeof(int32_t));
    if (count > kMaxCount) throw std::runtime_error("input too large");
    std::vector<int32_t> received(static_cast<size_t>(count));
    if (count != 0) read_all(fd, received.data(), received.size() * sizeof(int32_t));
    return received;
}
/// Merges two ascending index-linked lists stored in `nodes` and returns the
/// head of the merged list.
///
/// The merge is stable: on equal values the node from `left` comes first.
/// Nodes are relinked in place through their `next` fields; nothing is
/// allocated. If either list is NIL the other is returned as-is.
static int merge_lists(std::vector<Node>& nodes, int left, int right) {
    int merged = NIL;
    int* link = &merged;  // slot that will receive the next chosen node index
    while (left != NIL && right != NIL) {
        // Take from the list whose head is smaller (left on ties) and advance it.
        int& source = (nodes[left].value <= nodes[right].value) ? left : right;
        *link = source;
        link = &nodes[source].next;
        source = *link;
    }
    // At most one list still has nodes; append it (or terminate with NIL).
    *link = (left != NIL) ? left : right;
    return merged;
}
/// Cuts the list starting at `head` into two halves.
///
/// On return `left` is the original head and `right` is the head of the
/// second half. For a list of n nodes the first half has ceil(n / 2) nodes.
/// Lists with zero or one node yield left = head, right = NIL. The cut
/// writes NIL into the last node of the first half.
static void split_list(std::vector<Node>& nodes, int head, int& left, int& right) {
    left = head;
    right = NIL;
    if (head == NIL || nodes[head].next == NIL) return;
    // `probe` moves two nodes per step while `middle` moves one, so `middle`
    // ends on the last node of the first half.
    int middle = head;
    for (int probe = nodes[head].next;
         probe != NIL && nodes[probe].next != NIL;
         probe = nodes[nodes[probe].next].next) {
        middle = nodes[middle].next;
    }
    right = nodes[middle].next;
    nodes[middle].next = NIL;
}
/// Recursive top-down merge sort of the index-linked list starting at
/// `head`; returns the head of the sorted list. Sorting is stable because
/// merge_lists prefers the first half on ties.
static int local_sort(std::vector<Node>& nodes, int head) {
    const bool already_sorted = head == NIL || nodes[head].next == NIL;
    if (already_sorted) return head;
    int front = NIL;
    int back = NIL;
    split_list(nodes, head, front, back);
    const int sorted_front = local_sort(nodes, front);
    const int sorted_back = local_sort(nodes, back);
    return merge_lists(nodes, sorted_front, sorted_back);
}
/// Builds an index-linked list over `values`. Node i holds values[i] and
/// links to node i + 1; the last node links to NIL. An empty input gives an
/// empty vector.
///
/// Node indices are `int`, so inputs with more than INT_MAX elements are
/// rejected. Otherwise `static_cast<int>(i + 1)` would overflow and corrupt
/// the links.
///
/// @throws std::runtime_error if values.size() exceeds INT_MAX.
static std::vector<Node> make_nodes(const std::vector<int32_t>& values) {
    constexpr size_t kMaxNodes = static_cast<size_t>(std::numeric_limits<int>::max());
    if (values.size() > kMaxNodes) {
        throw std::runtime_error("input too large for int-indexed list");
    }
    std::vector<Node> nodes(values.size());
    for (size_t i = 0; i < values.size(); ++i) {
        nodes[i].value = values[i];
        nodes[i].next = (i + 1 < values.size()) ? static_cast<int>(i + 1) : NIL;
    }
    return nodes;
}
static std::vector<int32_t> collect_values(const std::vector<Node>& nodes, int head) {
std::vector<int32_t> out;
out.reserve(nodes.size());
for (int cur = head; cur != NIL; cur = nodes[cur].next) {
out.push_back(nodes[cur].value);
}
return out;
}
/// Sorts `values` ascending in-process. It converts them to an index-linked
/// list, merge-sorts the list, and reads the values back out. Inputs with
/// fewer than two elements are returned unchanged.
static std::vector<int32_t> sort_values_as_list(std::vector<int32_t> values) {
    if (values.size() >= 2) {
        std::vector<Node> nodes = make_nodes(values);
        const int sorted_head = local_sort(nodes, 0);
        return collect_values(nodes, sorted_head);
    }
    return values;
}
/// Splits `values` into two halves using the same split as split_list.
/// `left_values` receives the first ceil(n / 2) elements and `right_values`
/// the rest, both in their original order. Both outputs are overwritten.
static void split_values_as_list(const std::vector<int32_t>& values,
                                 std::vector<int32_t>& left_values,
                                 std::vector<int32_t>& right_values) {
    std::vector<Node> nodes = make_nodes(values);
    const int head = values.empty() ? NIL : 0;
    int first_half = NIL;
    int second_half = NIL;
    split_list(nodes, head, first_half, second_half);
    left_values = collect_values(nodes, first_half);
    right_values = collect_values(nodes, second_half);
}
/// Merges two ascending vectors into one ascending vector. Each input is
/// turned into an index-linked list and the lists are combined with
/// merge_lists, which is stable (elements of `left_values` win ties).
///
/// Node indices are `int`, so the combined size must not exceed INT_MAX;
/// larger inputs are rejected instead of overflowing the index conversion.
///
/// @throws std::runtime_error if the combined size exceeds INT_MAX.
static std::vector<int32_t> merge_values_as_list(const std::vector<int32_t>& left_values,
                                                 const std::vector<int32_t>& right_values) {
    const size_t total = left_values.size() + right_values.size();
    if (total > static_cast<size_t>(std::numeric_limits<int>::max())) {
        throw std::runtime_error("input too large for int-indexed list");
    }
    std::vector<Node> nodes;
    nodes.reserve(total);
    // Appends `values` to `nodes` as one chain; returns its head (NIL if empty).
    auto append_chain = [&nodes](const std::vector<int32_t>& values) -> int {
        int head = NIL;
        int tail = NIL;
        for (int32_t value : values) {
            const int idx = static_cast<int>(nodes.size());
            nodes.push_back({value, NIL});
            if (head == NIL) {
                head = idx;
            } else {
                nodes[tail].next = idx;
            }
            tail = idx;
        }
        return head;
    };
    const int left = append_chain(left_values);
    const int right = append_chain(right_values);
    const int head = merge_lists(nodes, left, right);
    return collect_values(nodes, head);
}
static SortResult parallel_sort(std::vector<int32_t> values, int depth, const Options& opt);
/// Forks a child process that sorts `part` and streams the result back.
///
/// The parent sends `part` to the child over one pipe (send_values). The
/// child runs parallel_sort(part, depth, opt) and answers on a second pipe.
/// The child exits with status 0 on success and 2 on any exception, using
/// _exit so it never returns into the caller's stack.
///
/// @param part       values the child should sort.
/// @param depth      depth passed to the child's parallel_sort.
/// @param opt        options forwarded unchanged to the child.
/// @param result_fd  on success, set to the read end of the result pipe. The
///                   caller must read from it, close it, and reap the child.
/// @return the child's pid.
/// @throws std::runtime_error if a pipe cannot be created, fork fails, or
///         sending the input fails. On every failure path the descriptors
///         created here are closed, and an already-forked child is reaped.
static pid_t spawn_child(const std::vector<int32_t>& part,
                         int depth,
                         const Options& opt,
                         int& result_fd) {
    int to_child[2] = {-1, -1};
    int from_child[2] = {-1, -1};
    // Closes every descriptor opened so far, then throws with the errno that
    // caused the failure (saved first, in case closing modifies it).
    auto fail = [&](const char* what) {
        const int saved_errno = errno;
        close_checked(to_child[0]);
        close_checked(to_child[1]);
        close_checked(from_child[0]);
        close_checked(from_child[1]);
        errno = saved_errno;
        throw_errno(what);
    };
    if (pipe(to_child) < 0) throw_errno("pipe to_child");
    if (pipe(from_child) < 0) fail("pipe from_child");
    const pid_t pid = fork();
    if (pid < 0) fail("fork");
    if (pid == 0) {
        // Child process: read input, sort, reply, and leave via _exit.
        try {
            close_checked(to_child[1]);
            close_checked(from_child[0]);
            std::vector<int32_t> input = recv_values(to_child[0]);
            close_checked(to_child[0]);
            SortResult result = parallel_sort(std::move(input), depth, opt);
            send_values(from_child[1], result.values);
            close_checked(from_child[1]);
            _exit(0);
        } catch (...) {
            _exit(2);
        }
    }
    // Parent process: keep only the write end of the input pipe and the read
    // end of the result pipe.
    close_checked(to_child[0]);
    close_checked(from_child[1]);
    try {
        send_values(to_child[1], part);
    } catch (...) {
        // Once both of our ends are closed the child hits EOF on its input
        // (or a broken pipe on its output) and exits, so waiting cannot hang.
        close_checked(to_child[1]);
        close_checked(from_child[0]);
        int status = 0;
        while (waitpid(pid, &status, 0) < 0 && errno == EINTR) {}
        throw;
    }
    close_checked(to_child[1]);
    result_fd = from_child[0];
    return pid;
}
/// Sorts `values` ascending.
///
/// A piece is sorted in-process with the linked-list merge sort when any of
/// these holds:
///   - it has fewer than two elements;
///   - `depth` has reached `opt.max_depth`;
///   - it has at most `opt.min_size` elements.
/// Otherwise it is split in half, each half is sorted by a forked child
/// (recursively, at depth + 1), and the two sorted halves are merged.
///
/// Both children are always reaped and both result pipes closed before this
/// function returns or throws, so failures leave no zombies or open fds.
///
/// @throws std::runtime_error if a child cannot be spawned, a child does not
///         exit with status 0 ("left child failed" / "right child failed"),
///         or reading a result fails.
static SortResult parallel_sort(std::vector<int32_t> values, int depth, const Options& opt) {
    if (values.size() < 2 || depth >= opt.max_depth || values.size() <= opt.min_size) {
        return {sort_values_as_list(std::move(values))};
    }
    std::vector<int32_t> left;
    std::vector<int32_t> right;
    split_values_as_list(values, left, right);

    // Waits for `pid` (retrying on EINTR); true only if it exited with status 0.
    // A failed waitpid counts as failure rather than an implicit success.
    auto reap = [](pid_t pid) -> bool {
        int status = 0;
        pid_t rc = -1;
        do {
            rc = waitpid(pid, &status, 0);
        } while (rc < 0 && errno == EINTR);
        return rc == pid && WIFEXITED(status) && WEXITSTATUS(status) == 0;
    };

    int left_fd = -1;
    int right_fd = -1;
    const pid_t left_pid = spawn_child(left, depth + 1, opt, left_fd);
    pid_t right_pid = -1;
    try {
        right_pid = spawn_child(right, depth + 1, opt, right_fd);
    } catch (...) {
        // Close our read end first so a left child blocked writing its result
        // gets a broken pipe and exits instead of hanging the wait.
        close_checked(left_fd);
        reap(left_pid);
        throw;
    }

    std::vector<int32_t> sorted_left;
    std::vector<int32_t> sorted_right;
    std::exception_ptr recv_error;
    try {
        sorted_left = recv_values(left_fd);
        sorted_right = recv_values(right_fd);
    } catch (...) {
        recv_error = std::current_exception();
    }
    close_checked(left_fd);
    close_checked(right_fd);
    const bool left_ok = reap(left_pid);
    const bool right_ok = reap(right_pid);
    // A failed child is the usual cause of a truncated stream, so report it
    // ahead of the (less specific) receive error.
    if (!left_ok) throw std::runtime_error("left child failed");
    if (!right_ok) throw std::runtime_error("right child failed");
    if (recv_error) std::rethrow_exception(recv_error);
    return {merge_values_as_list(sorted_left, sorted_right)};
}
/// Parses the command line into Options.
///
/// Flags:
///   --size / -n N, --depth / -d D, --min-size / -m M, --seed S — each takes
///     a non-negative decimal integer that must fit the target field;
///   --print;
///   --help / -h — prints usage and exits with status 0.
///
/// @throws std::runtime_error on an unknown flag, a missing or malformed
///         value (negative, trailing characters, out of range), a negative
///         depth, or a size of 0.
static Options parse_args(int argc, char** argv) {
    Options opt;
    // Parses `text` as a plain decimal integer in [0, max_value].
    // std::stoull by itself would accept "-1" (wrapping it to ULLONG_MAX) and
    // ignore trailing text such as "10abc", so the text must start with a
    // digit and be consumed completely.
    auto parse_unsigned = [](const std::string& name,
                             const std::string& text,
                             unsigned long long max_value) -> unsigned long long {
        bool ok = !text.empty() && text[0] >= '0' && text[0] <= '9';
        unsigned long long parsed = 0;
        if (ok) {
            try {
                size_t used = 0;
                parsed = std::stoull(text, &used, 10);
                ok = used == text.size() && parsed <= max_value;
            } catch (const std::exception&) {
                ok = false;  // does not fit in unsigned long long
            }
        }
        if (!ok) throw std::runtime_error("invalid value for " + name + ": " + text);
        return parsed;
    };
    constexpr unsigned long long kMaxSize = SIZE_MAX;
    constexpr unsigned long long kMaxDepth =
        static_cast<unsigned long long>(std::numeric_limits<int>::max());
    constexpr unsigned long long kMaxSeed = std::numeric_limits<unsigned>::max();
    for (int i = 1; i < argc; ++i) {
        const std::string s = argv[i];
        // Consumes and returns the argument following flag `name`.
        auto value = [&](const std::string& name) -> std::string {
            if (i + 1 >= argc) throw std::runtime_error("missing value for " + name);
            return argv[++i];
        };
        if (s == "--size" || s == "-n") {
            opt.size = static_cast<size_t>(parse_unsigned(s, value(s), kMaxSize));
        } else if (s == "--depth" || s == "-d") {
            opt.max_depth = static_cast<int>(parse_unsigned(s, value(s), kMaxDepth));
        } else if (s == "--min-size" || s == "-m") {
            opt.min_size = static_cast<size_t>(parse_unsigned(s, value(s), kMaxSize));
        } else if (s == "--seed") {
            opt.seed = static_cast<unsigned>(parse_unsigned(s, value(s), kMaxSeed));
        } else if (s == "--print") {
            opt.print = true;
        } else if (s == "--help" || s == "-h") {
            std::cout << "Usage: ./pipe_sort [--size N] [--depth D] [--min-size M] "
                      << "[--seed S] [--print]\n";
            std::exit(0);
        } else {
            throw std::runtime_error("unknown argument: " + s);
        }
    }
    // Validate final values regardless of whether they came from flags or defaults.
    if (opt.max_depth < 0) throw std::runtime_error("depth must be non-negative");
    if (opt.size == 0) throw std::runtime_error("size must be positive");
    return opt;
}
int main(int argc, char** argv) {
try {
Options opt = parse_args(argc, argv);
std::vector<int32_t> values(opt.size);
std::mt19937 rng(opt.seed);
std::uniform_int_distribution<int32_t> dist(-100000000, 100000000);
for (int32_t& value : values) value = dist(rng);
const auto t1 = std::chrono::steady_clock::now();
SortResult result = parallel_sort(std::move(values), 0, opt);
const auto t2 = std::chrono::steady_clock::now();
const double elapsed = std::chrono::duration<double>(t2 - t1).count();
const bool ok = std::is_sorted(result.values.begin(), result.values.end()) &&
result.values.size() == opt.size;
if (opt.print) {
for (int32_t value : result.values) std::cout << value << ' ';
std::cout << '\n';
}
std::cerr << "STAT: program=pipe size=" << opt.size
<< " depth=" << opt.max_depth
<< " min_size=" << opt.min_size
<< " valid=" << (ok ? 1 : 0)
<< " time=" << elapsed << " sec\n";
return ok ? 0 : 3;
} catch (const std::exception& e) {
std::cerr << "ERROR: " << e.what() << "\n";
return 1;
}
}