Gist by @blockloop, forked from kevinkreiser/spmc.cpp, created May 28, 2018 16:14.
Single Producer Multiple Consumer with ZeroMQ/ZMQ/ØMQ Butterfly Pattern
/* build and install libzmq from source:
 *
 *   wget http://download.zeromq.org/zeromq-4.1.0-rc1.tar.gz
 *   tar pxvf zeromq-4.1.0-rc1.tar.gz
 *   cd zeromq-4.1.0
 *   ./autogen.sh
 *   ./configure
 *   make
 *   sudo make install
 *
 * install the C++ binding (cppzmq, a single header):
 *
 *   wget -P /usr/local/include https://raw.githubusercontent.com/zeromq/cppzmq/master/zmq.hpp
 *
 * build and run this sample:
 *
 *   g++ spmc.cpp -o spmc -std=c++11 -O2 -g -lzmq
 *   time ./spmc 10000000 1
 *   time ./spmc 10000000 8
 *
 * this simulates a system that is pipelined and parallelized,
 * i.e. the zmq "butterfly" or "parallel pipeline" pattern:
 * http://zeromq.org/tutorials:butterfly
 *
 *   (http server)    request_producer
 *                     /      |      \
 *                    /       |       \
 *   (parallelize)  worker  worker  worker ...
 *                    \       |       /
 *                     \      |      /
 *   (rebalance)     intermediate_router
 *                     /      |      \
 *                    /       |       \
 *   (parallelize)  worker  worker  worker ...
 *                    \       |       /
 *                     \      |      /
 *   (reply)         response_collector
*/
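//note: this sample targets the 2014-era libzmq/cppzmq pinned in the header above. Newer
//cppzmq releases deprecate setsockopt()/recv(message_t*)/send(message_t&); a rough sketch
//of the modern equivalents (assuming a current zmq.hpp, not verified against the version
//installed above) would be:
//
//  socket.set(zmq::sockopt::sndhwm, 0);                       //instead of setsockopt(ZMQ_SNDHWM, ...)
//  zmq::message_t msg;
//  auto received = socket.recv(msg, zmq::recv_flags::none);   //instead of recv(&msg)
//  socket.send(msg, zmq::send_flags::none);                   //instead of send(msg)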
#include <zmq.hpp>
#include <thread>
#include <functional>
#include <memory>
#include <string>
#include <list>
#include <set>
#include <iostream>
namespace {

//produce messages to a group of workers
class producer {
  using produce_function_t = std::function<std::list<zmq::message_t> ()>;
public:
  producer(const std::shared_ptr<zmq::context_t>& context, const char* push_endpoint, const produce_function_t& produce_function):
      context(context), push_socket(*this->context, ZMQ_PUSH), produce_function(produce_function) {
    int high_water_mark = 0;
    push_socket.setsockopt(ZMQ_SNDHWM, &high_water_mark, sizeof(high_water_mark));
    push_socket.bind(push_endpoint);
  }
  void produce() {
    while(true) {
      auto messages = produce_function();
      if(messages.size() == 0)
        break;
      for(auto& message : messages)
        push_socket.send(message);
    }
  }
protected:
  std::shared_ptr<zmq::context_t> context;
  zmq::socket_t push_socket;
  produce_function_t produce_function;
};
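//note: a PUSH socket round-robins outgoing messages over whichever PULL peers are connected,
//which is what fans the produced requests out across the worker pool; a high water mark of 0
//means "no limit", so the producer queues messages instead of blocking or dropping when the
//workers fall behind.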
//forward message from one group of workers to the next
class router {
  using route_function_t = std::function<void (const zmq::message_t&)>;
public:
  router(const std::shared_ptr<zmq::context_t>& context, const char* pull_endpoint, const char* push_endpoint, const route_function_t& route_function = [](const zmq::message_t&){}):
      context(context), pull_socket(*this->context, ZMQ_PULL), push_socket(*this->context, ZMQ_PUSH), route_function(route_function) {
    int high_water_mark = 0;
    pull_socket.setsockopt(ZMQ_RCVHWM, &high_water_mark, sizeof(high_water_mark));
    pull_socket.bind(pull_endpoint);
    push_socket.setsockopt(ZMQ_SNDHWM, &high_water_mark, sizeof(high_water_mark));
    push_socket.bind(push_endpoint);
  }
  void route() {
    //keep forwarding messages
    while(true) {
      zmq::message_t message;
      pull_socket.recv(&message);
      route_function(message);
      push_socket.send(message);
    }
  }
protected:
  std::shared_ptr<zmq::context_t> context;
  zmq::socket_t pull_socket;
  zmq::socket_t push_socket;
  route_function_t route_function;
};
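//note: this router is effectively a streamer device: the PULL socket fair-queues results from
//the upstream workers and the PUSH socket round-robins them to the downstream workers, which
//rebalances work between stages. When no per-message callback is needed, a rough equivalent
//(assuming a cppzmq version that exposes it) is:
//
//  zmq::proxy(pull_socket, push_socket);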
//perform an action on a message and pass to next router/collector
class worker {
  using work_function_t = std::function<std::list<zmq::message_t> (const zmq::message_t&)>;
public:
  worker(const std::shared_ptr<zmq::context_t>& context, const char* pull_endpoint, const char* push_endpoint, const work_function_t& work_function):
      context(context), pull_socket(*this->context, ZMQ_PULL), push_socket(*this->context, ZMQ_PUSH), work_function(work_function) {
    int high_water_mark = 0;
    pull_socket.setsockopt(ZMQ_RCVHWM, &high_water_mark, sizeof(high_water_mark));
    pull_socket.connect(pull_endpoint);
    push_socket.setsockopt(ZMQ_SNDHWM, &high_water_mark, sizeof(high_water_mark));
    push_socket.connect(push_endpoint);
  }
  void work() {
    while(true) {
      zmq::message_t job;
      pull_socket.recv(&job);
      auto messages = work_function(job);
      for(auto& message : messages)
        push_socket.send(message);
    }
  }
protected:
  std::shared_ptr<zmq::context_t> context;
  zmq::socket_t pull_socket;
  zmq::socket_t push_socket;
  work_function_t work_function;
};
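//note: workers connect() to both endpoints while the producer/router/collector bind(), so any
//number of workers can attach or drop off at runtime; each message is delivered to exactly one
//worker (work distribution, not pub/sub fan-out).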
//collects completed work after passing through the entire butterfly
class collector {
  using collect_function_t = std::function<bool (const zmq::message_t&)>;
public:
  collector(const std::shared_ptr<zmq::context_t>& context, const char* pull_endpoint, const collect_function_t& collect_function):
      context(context), pull_socket(*this->context, ZMQ_PULL), collect_function(collect_function) {
    int high_water_mark = 0;
    pull_socket.setsockopt(ZMQ_RCVHWM, &high_water_mark, sizeof(high_water_mark));
    pull_socket.bind(pull_endpoint);
  }
  void collect() {
    while(true) {
      zmq::message_t message;
      pull_socket.recv(&message);
      if(collect_function(message))
        break;
    }
  }
protected:
  std::shared_ptr<zmq::context_t> context;
  zmq::socket_t pull_socket;
  collect_function_t collect_function;
};

}
int main(int argc, char** argv) {
  //number of jobs to do
  size_t requests = 10;
  if(argc > 1)
    requests = std::stoul(argv[1]);
  //number of workers to use at each stage
  size_t worker_concurrency = 1;
  if(argc > 2)
    worker_concurrency = std::stoul(argv[2]);
  //change these to tcp://known.ip.address.with:port if you want to do this across machines
  std::shared_ptr<zmq::context_t> context_ptr = std::make_shared<zmq::context_t>(1);
  const char* requests_in = "ipc://requests_in";
  const char* parsed_in = "ipc://parsed_in";
  const char* primes_in = "ipc://primes_in";
  const char* primes_out = "ipc://primes_out";
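  //note: the ipc:// transport rides on unix domain sockets, so this exact setup only works with
  //all stages on one POSIX host; for a multi-machine butterfly switch to tcp endpoints as noted
  //above, e.g. "tcp://192.0.2.10:5555" (a placeholder documentation address).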
  //request producer
  size_t produced_requests = 0;
  producer request_producer(context_ptr, requests_in,
    [requests, &produced_requests]() {
      std::list<zmq::message_t> messages;
      if(produced_requests != requests) {
        auto request = std::to_string(produced_requests * 2 + 3);
        messages.emplace_back(request.size());
        std::copy(request.begin(), request.end(), static_cast<char*>(messages.back().data()));
        ++produced_requests;
      }
      return messages;
    }
  );
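  //note: the produce function above emits the odd numbers 3, 5, 7, ... as ASCII strings, one per
  //message, and returns an empty list once `requests` messages have been produced, which tells
  //producer::produce() to stop.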
  //request parsers
  std::list<std::thread> parse_worker_threads;
  for(size_t i = 0; i < worker_concurrency; ++i) {
    parse_worker_threads.emplace_back(std::bind(&worker::work, worker(context_ptr, requests_in, parsed_in,
      [] (const zmq::message_t& message) {
        //parse the string into a size_t
        std::list<zmq::message_t> messages;
        messages.emplace_back(sizeof(size_t));
        const size_t possible_prime = std::stoul(std::string(static_cast<const char*>(message.data()), message.size()));
        *static_cast<size_t*>(messages.back().data()) = possible_prime;
        return messages;
      }
    )));
    parse_worker_threads.back().detach();
  }
  //router from parsed requests to prime computation workers
  std::thread prime_router_thread(std::bind(&router::route, router(context_ptr, parsed_in, primes_in)));
  prime_router_thread.detach();
  //prime computers
  std::list<std::thread> prime_worker_threads;
  for(size_t i = 0; i < worker_concurrency; ++i) {
    prime_worker_threads.emplace_back(std::bind(&worker::work, worker(context_ptr, primes_in, primes_out,
      [] (const zmq::message_t& message) {
        //check if it's prime
        const size_t prime = *static_cast<const size_t*>(message.data());
        size_t divisor = 2;
        size_t high = prime;
        while(divisor < high) {
          if(prime % divisor == 0)
            break;
          high = prime / divisor;
          ++divisor;
        }
        //if it was prime send it back unmolested, else send back 2 which we know is prime
        std::list<zmq::message_t> messages;
        messages.emplace_back(sizeof(size_t));
        *static_cast<size_t*>(messages.back().data()) = (divisor >= high ? prime : static_cast<size_t>(2));
        return messages;
      }
    )));
    prime_worker_threads.back().detach();
  }
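  //note on the primality check above: `high` is tightened to prime / divisor on every iteration,
  //so the loop ends once divisor exceeds prime / divisor, i.e. roughly sqrt(prime). For example,
  //for 9: 9 % 2 != 0 so high becomes 4, then 9 % 3 == 0 and the loop breaks, marking 9 composite.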
  //result collector
  std::set<size_t> primes = {2};
  size_t collected_results = 0;
  std::thread collector_thread(std::bind(&collector::collect, collector(context_ptr, primes_out,
    [requests, &primes, &collected_results] (const zmq::message_t& message) {
      primes.insert(*static_cast<const size_t*>(message.data()));
      ++collected_results;
      return collected_results == requests;
    }
  )));
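  //note: composite inputs come back as 2, which is already seeded into the set, so the count
  //printed at the end is 1 (for the seed) plus the number of distinct odd primes found among
  //the tested values.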
  //the producer is started last so the rest of the pipeline is already wired up and listening
  std::thread producer_thread(std::bind(&producer::produce, &request_producer));
  //wait for the collector to get all the jobs
  collector_thread.join();
  producer_thread.join();
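  //note: the worker and router threads were detached and loop forever; they are only torn down
  //when main returns and the process exits.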
  //show primes
  //for(const auto& prime : primes)
  //  std::cout << prime << " | ";
  std::cout << primes.size() << std::endl;
  return 0;
}