Build a Postgres Proxy in Elixir using Pattern Matching

Kyle Hanson

Want to learn how your application talks to your database? Build a proxy using Elixir, with its powerful pattern matching and gen_tcp, to take your database understanding to the next level. In this article we build a Postgres proxy in Elixir to show that all you need is a little curiosity to master one of the most popular SQL protocols around.

Building a proxy can be a great way to understand a protocol. Not everything you build needs to run in production to be valuable to your understanding of the tech. The better you understand the tech, the more powerful the tech becomes.

In this post, we will explore how to accept a socket and proxy it to another, while at the same time parsing the stream. At the end of the post we can use our proxy to control which queries get executed on our postgres server.

Build the Server

To build this proxy, I used the postgres reference documentation regarding the protocol.

This blog post assumes you are familiar with Elixir and its Application structure. If you want to learn more about TCP connections and supervision, read the official Elixir article.

Accepting the socket

For this proxy we will open a TCP socket in active mode. To understand what active mode is, it helps to know what it isn't. When active is false (passive mode), the VM reads nothing from the socket until you call :gen_tcp.recv. When the socket is in active mode, the VM instead reads data from the socket as fast as possible and delivers it to the owning process as messages, using the process's mailbox as the buffer.
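The difference between the two modes can be seen on a loopback connection. This is a minimal sketch, not part of the proxy: the variable names are made up for illustration, and it assumes port 0 so the OS picks a free port.

```elixir
# Listen in passive mode; accepted sockets inherit the listener's options.
{:ok, listener} = :gen_tcp.listen(0, [:binary, active: false, reuseaddr: true])
{:ok, port} = :inet.port(listener)

# The client side is opened in active mode.
{:ok, client} = :gen_tcp.connect({127, 0, 0, 1}, port, [:binary, active: true])
{:ok, server} = :gen_tcp.accept(listener)

# Passive side: data stays in the socket until we explicitly ask for it.
:ok = :gen_tcp.send(client, "ping")
{:ok, passive_data} = :gen_tcp.recv(server, 4, 1_000)

# Active side: the VM delivers data to the owning process as a message.
:ok = :gen_tcp.send(server, "pong")

active_data =
  receive do
    {:tcp, ^client, data} -> data
  after
    1_000 -> :timeout
  end

:gen_tcp.close(client)
:gen_tcp.close(server)
:gen_tcp.close(listener)
```

Because an active socket never waits for the reader, a slow consumer can accumulate a large mailbox. For a learning proxy that is an acceptable trade-off.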

Because we will be reading data as fast as possible from both upstream and downstream, we will need to split out our listener into two processes. One process will read upstream data, parse it into commands, and then forward those commands to the Postgres database. The other process will read the responses from the Postgres database and forward them directly back to the client.

defmodule Statetrace.PostgresProxy do
require Logger

def accept(port) do
{:ok, socket} =
:gen_tcp.listen(port, [:binary, active: true, reuseaddr: true, packet: 0, nodelay: true])

Logger.info("Accepting connections on port #{port}")
loop_acceptor(socket)
end

defp loop_acceptor(socket) do
{:ok, client} = :gen_tcp.accept(socket)

{:ok, outbound} =
:gen_tcp.connect('localhost', 5432, [:binary, active: true, packet: 0, nodelay: true])

{:ok, pid} =
Task.Supervisor.start_child(Statetrace.TaskSupervisor, fn ->
serve_upstream(client, outbound, nil, true)
end)

:ok = :gen_tcp.controlling_process(client, pid)

{:ok, pid2} =
Task.Supervisor.start_child(Statetrace.TaskSupervisor, fn ->
serve_downstream(client, outbound)
end)

:ok = :gen_tcp.controlling_process(outbound, pid2)

loop_acceptor(socket)
end
end

We use :gen_tcp.connect to connect to the real postgres database and spin it off into a loop of its own to pipe responses back to the client.

Serving upstream

Now we build the parser for the upstream connection.

defmodule Statetrace.PostgresProxy do

...

def serve_upstream(socket, outbound, nil, is_first) do
data = socket |> read_line()

handle_parse(socket, outbound, parse_msg(data, is_first))
end

def serve_upstream(socket, outbound, fun, _is_first) do
data = socket |> read_line()

r = fun.(data)

handle_parse(socket, outbound, r)
end

def read_line(socket) do
receive do
{:tcp, ^socket, data} -> data
end
end
end

Parsing a message is simple. The protocol is a length-prefixed binary format: each message starts with a one-byte tag, followed by a 32-bit big-endian length that counts itself and the payload but not the tag. The very first message of a connection is the only exception: it has no tag. There is a chance that the data we have received so far is shorter than the message length. In that case we return a continuation, so that when the next chunk of data comes in we can check whether it completes the message.
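To make the framing concrete, here is a small sketch that builds a Query message by hand and splits it back off a buffer. The module and function names (FramingSketch, query_message, split) are hypothetical and not part of the proxy. Unlike the parser in this post, split also reports :incomplete when the five-byte header itself has not fully arrived.

```elixir
defmodule FramingSketch do
  # Build a Query ('Q') message: tag, 32-bit length, SQL text, NUL terminator.
  # The length counts its own four bytes plus the body, but not the tag.
  def query_message(sql) do
    body = sql <> <<0>>
    len = byte_size(body) + 4
    <<?Q, len::unsigned-integer-32, body::binary>>
  end

  # Split one tagged message off the front of a buffer.
  def split(<<tag::8, len::unsigned-integer-32, tail::binary>>)
      when len >= 4 and byte_size(tail) >= len - 4 do
    payload_size = len - 4
    <<payload::binary-size(payload_size), rest::binary>> = tail
    {:ok, tag, payload, rest}
  end

  # Header or payload not fully received yet.
  def split(_buffer), do: :incomplete
end

msg = FramingSketch.query_message("SELECT 1;")
complete = FramingSketch.split(msg <> "extra")
partial = FramingSketch.split(binary_part(msg, 0, 7))
```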

defmodule Statetrace.PostgresProxy do

...

# On the first message don't extract the tag
defp parse_msg(bin, true) do
# Use pattern matching to extract the length
<<len::unsigned-integer-32, _other_rest::binary>> = bin

case bin do
# Pattern match to see if our binary is big enough
<<msg_body::binary-size(len), final_rest::binary>> ->
{:ok, {{:msgStartup, nil}, msg_body}, final_rest}

_ ->
{:continuation,
fn data ->
handle_continuation(len, {:msgStartup, nil}, bin, data)
end}
end
end

# Pattern match the binary to extract the tag.
defp parse_msg(<<c::size(8), rest::binary>>, false) do
tag = tag_to_msg_type(c)

# Use pattern matching to extract the length
<<len::unsigned-integer-32, _other_rest::binary>> = rest

case rest do
# Pattern match to see if our binary is big enough
<<msg_body::binary-size(len), final_rest::binary>> ->
{:ok, {{tag, c}, msg_body}, final_rest}

_ ->
{:continuation,
fn data ->
handle_continuation(len, {tag, c}, rest, data)
end}
end
end

def handle_continuation(l, tag, other, data) do
new_data = other <> data

case new_data do
<<msg_body::binary-size(l), rest::binary>> ->
{:ok, {tag, msg_body}, rest}

_ ->
{:continuation,
fn data ->
handle_continuation(l, tag, new_data, data)
end}
end
end
end
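The continuation idea is easier to see in isolation. The sketch below is a stripped-down illustration with a hypothetical module name (ContinuationSketch), not the proxy's own code: it keeps asking for more data until the buffer holds the requested number of bytes.

```elixir
defmodule ContinuationSketch do
  # Try to take `len` bytes from `buffer`. If the buffer is too short, return
  # a closure that retries once more data has been appended.
  def take(len, buffer) do
    case buffer do
      <<msg::binary-size(len), rest::binary>> ->
        {:ok, msg, rest}

      _ ->
        {:continuation, fn more -> take(len, buffer <> more) end}
    end
  end
end

# Feed an 8-byte "message" in three chunks.
{:continuation, k1} = ContinuationSketch.take(8, "abc")
{:continuation, k2} = k1.("de")
result = k2.("fghij")
```

The closure carries the partial buffer with it, so the receive loop needs no extra state beyond the continuation itself.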

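The first byte of every message after the startup message is a tag. We will convert this tag into an atom. Note that the protocol reuses tag bytes depending on direction: D is DataRow when sent by the server but Describe when sent by the client, for example. The table below mixes names from both directions; for this proxy only :msgQuery matters.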

defmodule Statetrace.PostgresProxy do

...

defp tag_to_msg_type(val) do
case val do
?1 -> :msgParseComplete
?2 -> :msgBindComplete
?3 -> :msgCloseComplete
?A -> :msgNotificationResponse
?c -> :msgCopyDone
?C -> :msgCommandComplete
?d -> :msgCopyData
?D -> :msgDataRow
?E -> :msgErrorResponse
?f -> :msgFail
?G -> :msgCopyInResponse
?H -> :msgCopyOutResponse
?I -> :msgEmptyQueryResponse
?K -> :msgBackendKeyData
?n -> :msgNoData
?N -> :msgNoticeResponse
?R -> :msgAuthentication
?s -> :msgPortalSuspended
?S -> :msgParameterStatus
?t -> :msgParameterDescription
?T -> :msgRowDescription
?p -> :msgPasswordMessage
?W -> :msgCopyBothResponse
?Q -> :msgQuery
?X -> :msgTerminate
?Z -> :msgReadyForQuery
?P -> :msgParse
?B -> :msgBind
_ -> :msgNoTag
end
end
end
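Since the upstream side only ever sees messages sent by the client, a frontend-only table may be easier to reason about. The following is a sketch based on the tags the protocol documentation lists for frontend messages in protocol version 3.0; the module name and atom names are hypothetical.

```elixir
defmodule FrontendTagSketch do
  # Map a tag byte, as sent by the client, to a descriptive atom.
  def name(tag) do
    case tag do
      ?B -> :bind
      ?C -> :close
      ?c -> :copy_done
      ?d -> :copy_data
      ?D -> :describe
      ?E -> :execute
      ?f -> :copy_fail
      ?F -> :function_call
      ?H -> :flush
      # 'p' is shared by password, GSS, and SASL responses.
      ?p -> :password_or_auth_response
      ?P -> :parse
      ?Q -> :query
      ?S -> :sync
      ?X -> :terminate
      _ -> :unknown
    end
  end
end
```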

Finally we will handle the parse result.

defmodule Statetrace.PostgresProxy do

...

def handle_parse(socket, outbound, {:continuation, continuation}) do
serve_upstream(socket, outbound, continuation, false)
end

def handle_parse(socket, outbound, {:ok, {{_msgType, c}, data}, left_over}) do
to_send =
case c do
nil -> data
_ -> [c, data]
end

:ok = :gen_tcp.send(outbound, to_send)

case left_over do
"" -> serve_upstream(socket, outbound, nil, false)
_ -> handle_parse(socket, outbound, parse_msg(left_over, false))
end
end
end

Serving downstream

The downstream side is even simpler. We do not parse the messages; we simply forward the data directly to the client socket.

defmodule Statetrace.PostgresProxy do

...

def serve_downstream(socket, outbound) do
data = outbound |> read_line()
:ok = :gen_tcp.send(socket, data)

serve_downstream(socket, outbound)
end
end
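Both loops rely on read_line, which only matches {:tcp, socket, data}. An active socket also sends {:tcp_closed, socket} and {:tcp_error, socket, reason} messages, so as written a task waits forever once its peer disconnects. One possible variant is sketched below; the names ReadSketch, read_chunk, and forward are hypothetical.

```elixir
defmodule ReadSketch do
  # Like read_line/1, but also reports a closed or failed socket.
  def read_chunk(socket) do
    receive do
      {:tcp, ^socket, data} -> {:ok, data}
      {:tcp_closed, ^socket} -> :closed
      {:tcp_error, ^socket, reason} -> {:error, reason}
    end
  end

  # Forward data from one socket to the other until the source goes away,
  # then close the destination so the other side notices as well.
  def forward(from, to) do
    case read_chunk(from) do
      {:ok, data} ->
        :ok = :gen_tcp.send(to, data)
        forward(from, to)

      other ->
        :gen_tcp.close(to)
        other
    end
  end
end
```

Returning a value instead of looping forever lets the task finish normally, which in turn releases the socket it owns.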

Complete server

The complete server is fewer than 200 lines of code.

defmodule Statetrace.PostgresProxy do
require Logger

def accept(port) do
{:ok, socket} =
:gen_tcp.listen(port, [:binary, active: true, reuseaddr: true, packet: 0, nodelay: true])

Logger.info("Accepting connections on port #{port}")
loop_acceptor(socket)
end

defp loop_acceptor(socket) do
{:ok, client} = :gen_tcp.accept(socket)

{:ok, outbound} =
:gen_tcp.connect('localhost', 5432, [:binary, active: true, packet: 0, nodelay: true])

{:ok, pid} =
Task.Supervisor.start_child(Statetrace.TaskSupervisor, fn ->
serve_upstream(client, outbound, nil, true)
end)

:ok = :gen_tcp.controlling_process(client, pid)

{:ok, pid2} =
Task.Supervisor.start_child(Statetrace.TaskSupervisor, fn ->
serve_downstream(client, outbound)
end)

:ok = :gen_tcp.controlling_process(outbound, pid2)

loop_acceptor(socket)
end

defp serve_upstream(socket, outbound, nil, is_first) do
data = socket |> read_line()

handle_parse(socket, outbound, parse_msg(data, is_first))
end

defp serve_upstream(socket, outbound, fun, _is_first) do
data = socket |> read_line()

r = fun.(data)

handle_parse(socket, outbound, r)
end

defp handle_parse(socket, outbound, {:continuation, continuation}) do
serve_upstream(socket, outbound, continuation, false)
end

defp handle_parse(socket, outbound, {:ok, {{_msgType, c}, data}, left_over}) do
to_send =
case c do
nil -> data
_ -> [c, data]
end

:ok = :gen_tcp.send(outbound, to_send)

case left_over do
"" -> serve_upstream(socket, outbound, nil, false)
_ -> handle_parse(socket, outbound, parse_msg(left_over, false))
end
end

defp serve_downstream(socket, outbound) do
data = outbound |> read_line()
:ok = :gen_tcp.send(socket, data)

serve_downstream(socket, outbound)
end

defp read_line(socket) do
receive do
{:tcp, ^socket, data} -> data
end
end

defp tag_to_msg_type(val) do
case val do
?1 -> :msgParseComplete
?2 -> :msgBindComplete
?3 -> :msgCloseComplete
?A -> :msgNotificationResponse
?c -> :msgCopyDone
?C -> :msgCommandComplete
?d -> :msgCopyData
?D -> :msgDataRow
?E -> :msgErrorResponse
?f -> :msgFail
?G -> :msgCopyInResponse
?H -> :msgCopyOutResponse
?I -> :msgEmptyQueryResponse
?K -> :msgBackendKeyData
?n -> :msgNoData
?N -> :msgNoticeResponse
?R -> :msgAuthentication
?s -> :msgPortalSuspended
?S -> :msgParameterStatus
?t -> :msgParameterDescription
?T -> :msgRowDescription
?p -> :msgPasswordMessage
?W -> :msgCopyBothResponse
?Q -> :msgQuery
?X -> :msgTerminate
?Z -> :msgReadyForQuery
?P -> :msgParse
?B -> :msgBind
_ -> :msgNoTag
end
end

defp parse_msg(bin, true) do
<<len::unsigned-integer-32, _other_rest::binary>> = bin

case bin do
<<msg_body::binary-size(len), final_rest::binary>> ->
{:ok, {{:msgStartup, nil}, msg_body}, final_rest}

_ ->
{:continuation,
fn data ->
handle_continuation(len, {:msgStartup, nil}, bin, data)
end}
end
end

defp parse_msg(<<c::size(8), rest::binary>>, false) do
tag = tag_to_msg_type(c)

<<len::unsigned-integer-32, _other_rest::binary>> = rest

case rest do
<<msg_body::binary-size(len), final_rest::binary>> ->
{:ok, {{tag, c}, msg_body}, final_rest}

_ ->
{:continuation,
fn data ->
handle_continuation(len, {tag, c}, rest, data)
end}
end
end

defp handle_continuation(l, tag, other, data) do
new_data = other <> data

case new_data do
<<msg_body::binary-size(l), rest::binary>> ->
{:ok, {tag, msg_body}, rest}

_ ->
{:continuation,
fn data ->
handle_continuation(l, tag, new_data, data)
end}
end
end
end

To start it with your application, add it to your supervisor tree:

defmodule Statetrace.Application do
use Application

def start(_type, _args) do

children = [
{Task.Supervisor, name: Statetrace.TaskSupervisor},
Supervisor.child_spec({Task, fn -> Statetrace.PostgresProxy.accept(5433) end},
restart: :permanent
)

]
opts = [strategy: :one_for_one, name: Statetrace.Supervisor]
Supervisor.start_link(children, opts)
end
end

Connecting

Now you can connect to your proxy from postgres clients in every language!

In Elixir:

{:ok, conn} = Postgrex.start_link(host: "localhost", port: 5433, password: "postgres", username: "postgres", database: "postgres")
Postgrex.query!(conn, "SELECT 1;", [])

In Python:

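import psycopg2

# sslmode=disable: libpq-based clients otherwise open with an SSLRequest,
# which this simple proxy does not handle.
conn = psycopg2.connect("dbname=postgres host=localhost port=5433 user=postgres password=postgres sslmode=disable")
cur = conn.cursor()
cur.execute("SELECT 1;")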
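One caveat when connecting: the proxy treats only the very first packet as untagged. Clients built on libpq default to sslmode=prefer and therefore open with an SSLRequest, receive a one-byte answer, and only then send the untagged StartupMessage. A more complete proxy would need to recognise these opening packets. The sketch below shows how pattern matching could classify them; the module name is hypothetical, and the request codes are the ones given in the protocol documentation.

```elixir
defmodule StartupSketch do
  # Classify the first, untagged packet a client sends.
  # SSLRequest and GSSENCRequest are 8 bytes: length 8 plus a magic code.
  def classify(<<8::32, 80_877_103::32>>), do: :ssl_request
  def classify(<<8::32, 80_877_104::32>>), do: :gssenc_request

  # StartupMessage: length, protocol major version 3, minor version, params.
  def classify(<<_len::32, 3::16, _minor::16, _params::binary>>), do: :startup_v3

  def classify(_other), do: :unknown
end
```

A proxy that sees :ssl_request could, for instance, answer with a single N byte to decline TLS and then expect another untagged message; that behaviour is not implemented in this post.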

Controlling Postgres Statements

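So you're probably asking: what can we do with this proxy? If we change our handle_parse function, we can "police" the queries made to the upstream server. The format of the query message is simple: it is the query string you submitted, prefixed with the length and suffixed with a null byte. Note that this covers only the simple query protocol (the Q message); drivers that use the extended protocol, such as Postgrex, send their SQL in Parse messages, which the code below forwards unchecked.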

defmodule Statetrace.PostgresProxy do
...

defp query_allowed?(q) do
# In reality you will want something more secure than this.
# This check is case-sensitive, so "select 1" would be rejected.
String.starts_with?(q, "SELECT")
end

# Use pattern matching to process :msgQuery differently.
# The query is prefixed with the length and suffixed with a null byte
defp handle_parse(socket, outbound, {:ok, {{:msgQuery, _c}, <<_len::unsigned-integer-32, data::binary>>}, _left_over} = msg) do
query = String.trim_trailing(data, <<0>>)
Logger.info("Query: #{query}")

if query_allowed?(query) do
do_handle_parse(socket, outbound, msg)
else
# Raising stops this task, which closes the client connection
raise "Unauthorized Query"
end
end

defp handle_parse(socket, outbound, msg) do
do_handle_parse(socket, outbound, msg)
end

defp do_handle_parse(socket, outbound, {:ok, {{_msgType, c}, data}, left_over}) do
to_send =
case c do
nil -> data
_ -> [c, data]
end

:ok = :gen_tcp.send(outbound, to_send)

case left_over do
"" -> serve_upstream(socket, outbound, nil, false)
_ -> handle_parse(socket, outbound, parse_msg(left_over, false))
end
end

defp do_handle_parse(socket, outbound, {:continuation, continuation}) do
serve_upstream(socket, outbound, continuation, false)
end
end

In production you will want to set permissions at the database level, specific to the connected user. However, this proxy can still be useful at the connection level. You can do things like make sure that queries to a certain table filter on certain columns. For instance, a SaaS platform might want an extra layer to make sure that all queries include an organization_id column in the WHERE clause.
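As a rough illustration of that last idea, the check could start as a simple textual test. This is only a sketch with a hypothetical module name: a regular expression is easy to fool (comments, string literals, subqueries), so a real implementation would want a proper SQL parser.

```elixir
defmodule TenantGuardSketch do
  # Naive check: does the query mention organization_id somewhere after WHERE?
  # Flags: i = case-insensitive, s = let "." match newlines.
  def scoped?(query) do
    Regex.match?(~r/\bWHERE\b.*\borganization_id\b/is, query)
  end
end
```

A function like this could be called from the query check in handle_parse in place of the String.starts_with? test.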

Conclusion

There is still more to the protocol, but we have shown that with a little patience you can produce a useful proxy that understands the protocol in less than 200 lines of code. Elixir's pattern matching simplifies extracting information from the binary stream and gen_tcp makes it simple to deal with sockets.
