Non-Profit, International

Spirit unsterblich.

C++ 协程 - 异步任务

字数统计:1960 blog

编写异步协程的 Task 实际上就是解决 4 个问题:协程生存期、异常传播、协程同步和结果发布。实际使用中,Promise 负责存储,Task 负责和外部进行交互。

接下来要编写一个最简单的同步协程任务类型,同步指的是它延迟启动并且顺序执行所有代码,因此不需要考虑任何线程安全问题。

该同步协程的用法如下:


sync_task switch_to_thread_pool()
{
    std::cout << "2. " << std::this_thread::get_id() << '\n';
    co_await resume_background();
    std::cout << "3. " << std::this_thread::get_id() << '\n';
}

fire_and_forget main_coro()
{
    std::cout << "1. " << std::this_thread::get_id() << '\n';
    co_await switch_to_thread_pool();
    std::cout << "4. " << std::this_thread::get_id() << '\n';
}

该代码将会使得 1 和 2 输出相同的线程 ID,3 和 4 输出相同的线程 ID。该最简协程实际上就能完成协程的最大作用,避免回调地狱:调用者(主协程) 不阻塞 等待任务的完成,而是被主动 恢复

在阅读下文之前,读者可以自行尝试编写这样的任务类型,然后再阅读下文。

由于同步协程的生存期完全内嵌于调用者,因此我们实际上不需要关心该问题,让协程在最终暂停点不暂停使协程自然销毁即可。

由于协程中抛出的异常是储存在 Promise 中的,因此 Promise 需要有一个非静态成员 std::exception_ptr exc 来储存该异常。

当前最简实现暂且只能返回 void,随后会补充它。

由于同步协程的执行在同步点发起(恢复),因此我们需要为 sync_task 编写 operator co_await 来调用协程并且恢复到调用者。

所以问题在于,如何实现同步?答案其实很简单,在 Promise 中储存下一个需要执行的协程的协程句柄即可。

因此,我们得到了以下代码:


template<typename T=void>
struct sync_task;

template<>
struct sync_task<>
{
    struct promise_type;
    std::coroutine_handle<promise_type> handle;
    struct promise_type
    {
        std::coroutine_handle<> next;
        std::exception_ptr exc;
    public:
        promise_type() noexcept
        {
        }
        sync_task get_return_object()
        {
            return {};
        }
        std::suspend_always initial_suspend()
        {
            return {};
        }
        auto final_suspend() noexcept
        {
            struct final_awaiter
            {
                promise_type& promise;
                bool await_ready() noxcept
                {
                    return bool(promise.next);
                }
                std::coroutine_handle<> await_suspend() noexcept
                {
                    return promise.next;
                }
                void await_resume() noexcept
                {
                }
            };
            return final_awaiter{ *this };
        }
        void return_void()
        {
        }
        void unhandled_exception() noexcept
        {
            exc = std::current_exception();
        }
    };
    auto operator co_await()
    {
        struct sync_awaiter
        {
            std::coroutine_handle<promise_type> handle;
            bool await_ready() {
                return false;
            }
            std::coroutine_handle<> await_suspend(std::coroutine_handle<> next)
            {
                handle.promise().next = next;
                return handle;
            }
            void await_resume()
            {
                auto& exc = handle.promise().exc;
                if (exc)
                    std::rethrow_exception(exc);
            }
        };
        return sync_awaiter{ handle };
    }
};

第一个秘密在于 final_suspend 返回的 final_awaiter 中。

首先在 await_ready 检查 next 是否指代协程,如果 next 不指代协程,那么说明没人和它同步(不存在下一个要执行的任务)。

然后,在 await_suspend 中返回 next。第四章中讲过,如果 await_suspend 返回 std::coroutine_handle<>,那么返回的协程句柄会立即被调用,因此实际上该相当于在函数体内调用再返回 void,但与之不同的是,前者可以避免栈溢出,该技术被称作对称转移(Symmetric Transfer)。

考虑以下代码:


sync_task switch_to_thread_pool()
{
    co_await resume_background();
}

fire_and_forget main_coro()
{
    co_await switch_to_thread_pool();
    co_await switch_to_thread_pool();
    co_await switch_to_thread_pool();
    ...
}

上述代码实际上不存在栈溢出问题,因为将协程句柄发送到线程池后,每次调用协程句柄,栈总是从事件循环中开始增长,从而无论等待多少次,栈的增长长度都是固定的。

而以下代码则可能会有不同:


sync_task switch_to_thread_pool()
{
    co_return;
}

fire_and_forget main_coro()
{
    co_await switch_to_thread_pool();
    co_await switch_to_thread_pool();
    co_await switch_to_thread_pool();
    ...
}

由于 switch_to_thread_pool 不再将自己发送到线程池,因此对于非对称转移的写法:


void await_suspend()
{
    promise.next();
}

每次恢复下一个任务都将会在之前的栈上进行增长,如果等待的次数过多就会导致栈溢出。

而对称转移的写法将协程句柄返回给调用者,消除了此次的栈增长,避免了溢出。

第二个秘密在于 operator co_await 返回的 sync_awaitersync_awaiter 的作用是将自己变为 next,同时恢复在初始暂停点暂停的协程。

co_await switch_to_thread_pool(); 使协程恢复的时候,await_resume 负责发布结果并且在存在异常时抛出异常。

现在,可以很容易的扩展无值的任务为有值的:


template<typename T>
struct sync_task
{
    struct promise_type;
    std::coroutine_handle<promise_type> handle;
    struct promise_type
    {
        std::coroutine_handle<> next;
        std::exception_ptr exc;
        std::optional<T> res;
    public:
        promise_type() noexcept
        {
        }
        sync_task get_return_object()
        {
            return {};
        }
        std::suspend_always initial_suspend()
        {
            return {};
        }
        auto final_suspend() noexcept
        {
            struct final_awaiter
            {
                promise_type& promise;
                bool await_ready() noxcept
                {
                    return bool(promise.next);
                }
                std::coroutine_handle<> await_suspend() noexcept
                {
                    return promise.next;
                }
                void await_resume() noexcept
                {
                }
            };
            return final_awaiter{ *this };
        }
        template<typename T>
        void return_value(T&& t)
        {
            res = std::forward<T>(t);
        }
        void unhandled_exception() noexcept
        {
            exc = std::current_exception();
        }
    };
    auto operator co_await()
    {
        struct sync_awaiter
        {
            std::coroutine_handle<promise_type> handle;
            bool await_ready() {
                return false;
            }
            std::coroutine_handle<> await_suspend(std::coroutine_handle<> next)
            {
                handle.promise().next = next;
                return handle;
            }
            T await_resume()
            {
                auto& promise = handle.promise();
                auto& exc = promise.exc;
                if (exc)
                    std::rethrow_exception(exc);
                return std::move(promise.res.value());
            }
        };
        return sync_awaiter{ handle };
    }
};


若无特殊声明,本人原创文章以 CC BY-SA 3.0 许可协议 提供。