Rust 多线程¶
TODO:
Rust 线程池简单实现¶
rust #并发 #并行¶
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
// 定义一个 Job 类型别名,它是一个 Boxed 的闭包
// FnOnce: 闭包可以被调用一次
// Send: 闭包可以被安全地发送到另一个线程
// 'static: 闭包不包含任何非静态生命周期的引用
type Job = Box<dyn FnOnce() + Send + 'static>;
// 线程池与工作线程之间传递的消息
enum Message {
NewJob(Job), // 新的任务
Terminate, // 终止信号
}
pub struct ThreadPool {
workers: Vec<Worker>,
sender: Option<mpsc::Sender<Message>>,
}
struct Worker {
id: usize,
thread: Option<thread::JoinHandle<()>>,
}
impl ThreadPool {
/// 创建一个新的 ThreadPool。
///
/// `threads` 参数是池中线程的数量。
///
/// # Panics
///
/// 如果 `threads` 为 0,`new` 函数会 panic。
pub fn new(threads: u32) -> Result<Self, PoolCreationError> {
if threads == 0 {
return Err(PoolCreationError("Thread count must be greater than 0".to_string()));
}
let (sender, receiver) = mpsc::channel();
let receiver = Arc::new(Mutex::new(receiver)); // 允许多个 worker 共享 receiver
let mut workers = Vec::with_capacity(threads as usize);
for id in 0..threads {
workers.push(Worker::new(id as usize, Arc::clone(&receiver)));
}
Ok(ThreadPool {
workers,
sender: Some(sender),
})
}
/// 将一个任务提交到线程池中执行。
///
/// `job` 是一个闭包,它将被发送到线程池中的某个空闲线程执行。
pub fn spawn<F>(&self, job: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(job);
// 使用 unwrap 是因为 sender 只会在 ThreadPool drop 时变为 None
self.sender
.as_ref()
.unwrap()
.send(Message::NewJob(job))
.expect("ThreadPool::spawn: Failed to send job to worker. Channel might be closed.");
}
}
// 当 ThreadPool 被丢弃时,确保所有线程都完成工作并被清理
impl Drop for ThreadPool {
fn drop(&mut self) {
println!("Sending terminate message to all workers.");
if let Some(sender) = self.sender.as_ref() {
for _ in &self.workers {
// 尽力发送,如果接收端已经关闭也无妨
let _ = sender.send(Message::Terminate);
}
}
// 关闭 sender,这样 worker 在尝试接收新任务时会知道没有更多任务了
// 这一步很重要,因为如果 worker 正在等待任务,而 sender 没有被 drop,
// 即使发送了 Terminate 消息,worker 也可能在 Terminate 消息被处理前
// 继续阻塞在 recv() 上。drop(sender) 会使 recv() 返回 Err,从而使 worker 退出。
drop(self.sender.take());
println!("Shutting down all workers.");
for worker in &mut self.workers {
println!("Shutting down worker {}", worker.id);
if let Some(thread) = worker.thread.take() { // take 出线程句柄并 join
match thread.join() {
Ok(_) => println!("Worker {} finished.", worker.id),
Err(e) => eprintln!("Worker {} panicked during shutdown: {:?}", worker.id, e),
}
}
}
println!("All workers have been shut down.");
}
}
impl Worker {
fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
let thread = thread::spawn(move || loop {
// 从通道接收消息
// lock() 会阻塞当前线程,直到互斥锁可用
// recv() 会阻塞当前线程,直到有消息可用或通道关闭
let message = receiver.lock().unwrap().recv();
match message {
Ok(Message::NewJob(job)) => {
// println!("Worker {} got a job; executing.", id);
job(); // 执行任务
// println!("Worker {} finished job.", id);
}
Ok(Message::Terminate) => {
// println!("Worker {} was told to terminate.", id);
break; // 收到终止信号,退出循环
}
Err(_) => {
// 当发送端关闭且通道中没有更多消息时,recv 会返回 Err
// println!("Worker {} disconnecting; sender dropped.", id);
break; // 发送端已关闭,退出循环
}
}
});
Worker {
id,
thread: Some(thread),
}
}
}
#[derive(Debug)]
pub struct PoolCreationError(String);
impl std::fmt::Display for PoolCreationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Failed to create thread pool: {}", self.0)
}
}
impl std::error::Error for PoolCreationError {}
// 使用示例
fn main() {
let pool = match ThreadPool::new(4) {
Ok(p) => p,
Err(e) => {
eprintln!("Error creating thread pool: {}", e);
return;
}
};
println!("Thread pool created with 4 threads.");
for i in 0..10 {
let task_id = i;
pool.spawn(move || {
println!("Task {} starting by thread {:?}", task_id, thread::current().id());
thread::sleep(std::time::Duration::from_secs(1));
println!("Task {} finished by thread {:?}", task_id, thread::current().id());
});
}
println!("All tasks submitted. Main thread will sleep for a bit to allow some tasks to start.");
thread::sleep(std::time::Duration::from_millis(500)); // 给一些任务启动的时间
println!("Main thread is now explicitly dropping the pool to trigger shutdown.");
// 当 pool 离开作用域时,它的 Drop trait 实现会被调用,从而优雅地关闭所有线程。
// 或者我们可以显式地 drop(pool);
drop(pool);
println!("Thread pool has been dropped. Main thread exiting.");
// 注意:主线程退出后,如果工作线程没有被正确 join,它们可能会被强制终止。
// Drop impl 确保了 join 的发生。
}