Passing file descriptors between processes


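This example shows how to pass file descriptors between processes in Go. The program re-executes itself as a child process; the child opens hello.txt and sends its descriptor to the parent over a Unix socket pair as SCM_RIGHTS control data, and then both processes write to the same file.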

package main

import (
        "fmt"
        "log"
        "net"
        "os"
        "os/exec"
        "syscall"
)

// main chooses which side of the demo to run based on the SUBPROCESS
// environment variable. The re-executed child (SUBPROCESS=true) runs
// subProcess. Any other invocation runs the parent logic in mainProcess.
func main() {
	log.SetFlags(log.Lshortfile | log.LstdFlags)
	switch os.Getenv("SUBPROCESS") {
	case "true":
		subProcess()
	default:
		mainProcess()
	}
}

// mainProcess is the parent side of the demo. It works in these steps:
//   - create a connected Unix socket pair;
//   - re-execute this binary with SUBPROCESS=true, passing one end of the
//     pair as the child's descriptor 3;
//   - receive the file descriptor the child sends back over the other end;
//   - write a line through that descriptor;
//   - wait for the child to exit.
//
// Setup or receive failures are fatal. Write and wait failures are only
// logged.
func mainProcess() {
	localFile, remoteFile, err := unixPair()
	if err != nil {
		log.Fatal(err)
	}
	defer localFile.Close()

	cmd := exec.Command(os.Args[0], os.Args[1:]...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	// ExtraFiles[i] becomes descriptor 3+i in the child.
	cmd.ExtraFiles = []*os.File{remoteFile}
	cmd.Env = append(os.Environ(), "SUBPROCESS=true")
	err = cmd.Start()
	// Once started, the child holds its own copy of the remote end. Close
	// the parent's copy now. If it stayed open and the child exited without
	// sending anything, the socket would never reach end-of-stream and
	// recvFd below would block forever.
	remoteFile.Close()
	if err != nil {
		log.Fatal(err)
	}

	local, err := net.FileConn(localFile)
	if err != nil {
		log.Fatal(err)
	}
	defer local.Close()
	unixConn, ok := local.(*net.UnixConn)
	if !ok {
		log.Fatal("not a unix conn")
	}

	file, err := recvFd(unixConn)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()
	fmt.Println("the fd in main process is ", file.Fd())
	if _, err := fmt.Fprintln(file, "hello from main process"); err != nil {
		log.Println(err)
	}

	if err := cmd.Wait(); err != nil {
		log.Println(err)
	}
}

// subProcess is the child side of the demo, selected when SUBPROCESS=true.
// It creates (or truncates) hello.txt and writes a line to it. It then sends
// that file's descriptor back to the parent over the Unix socket inherited as
// descriptor 3. Any failure is fatal.
func subProcess() {
	out, err := os.Create("hello.txt")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	fmt.Println("the fd in sub process is ", out.Fd())
	fmt.Fprintln(out, "hello from sub process")

	// Descriptor 3 is where the parent's cmd.ExtraFiles[0] lands in the child.
	inherited := os.NewFile(3, "unix-remote")
	defer inherited.Close()

	c, err := net.FileConn(inherited)
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	uc, isUnix := c.(*net.UnixConn)
	if !isUnix {
		log.Fatal("sub process: ", "not a unix conn")
	}
	if err := sendFd(uc, out); err != nil {
		log.Fatal(err)
	}
}

// unixPair returns the two ends of a connected AF_UNIX stream socket pair.
// They are *os.File values named "unix-local" and "unix-remote".
//
// Both descriptors are marked close-on-exec, so child processes do not
// inherit them by accident. os/exec's Cmd.ExtraFiles still installs any end
// that is passed to it explicitly. syscall.ForkLock is held for reading from
// creation until the flag is set, so a concurrent fork/exec cannot capture
// the descriptors in between.
func unixPair() (*os.File, *os.File, error) {
	syscall.ForkLock.RLock()
	fd, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
	if err == nil {
		syscall.CloseOnExec(fd[0])
		syscall.CloseOnExec(fd[1])
	}
	syscall.ForkLock.RUnlock()
	if err != nil {
		return nil, nil, err
	}
	localFile := os.NewFile(uintptr(fd[0]), "unix-local")
	remoteFile := os.NewFile(uintptr(fd[1]), "unix-remote")
	return localFile, remoteFile, nil
}

// sendFd sends file's descriptor to the peer of unixLocal as SCM_RIGHTS
// control data. The caller supplies no regular payload bytes. Per unix(7),
// the receiver gets its own descriptor for the same open file. sendFd
// returns the error from WriteMsgUnix, if any.
//
// NOTE(review): stream Unix sockets generally need at least one data byte
// to carry control data. This relies on Go's syscall layer adding a dummy
// byte when the payload is empty; confirm on non-Linux targets.
func sendFd(unixLocal *net.UnixConn, file *os.File) error {
	rights := syscall.UnixRights(int(file.Fd()))
	if _, _, err := unixLocal.WriteMsgUnix(nil, rights, nil); err != nil {
		return err
	}
	return nil
}

// recvFd reads one message from unixConn. It returns the single file
// descriptor carried in the message's SCM_RIGHTS control data, wrapped as an
// *os.File named "remote-file". Up to 32 bytes of regular payload are read
// and discarded.
//
// recvFd returns an error if any of these hold:
//   - the read fails;
//   - the control data was truncated (MSG_CTRUNC);
//   - there is not exactly one control message;
//   - syscall.ParseUnixRights rejects that message;
//   - the message carries any number of descriptors other than one.
//
// Descriptors received alongside such an error are closed so they do not
// leak.
func recvFd(unixConn *net.UnixConn) (*os.File, error) {
	var (
		b   [32]byte
		oob [32]byte
	)
	_, oobn, flags, _, err := unixConn.ReadMsgUnix(b[:], oob[:])
	if err != nil {
		return nil, err
	}
	messages, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, err
	}

	// Every descriptor in the control data is already open in this process
	// once the read returns. Collect them all up front so each error path
	// below can close them.
	var (
		fds       []int
		rightsErr error
	)
	for i := range messages {
		got, perr := syscall.ParseUnixRights(&messages[i])
		if perr != nil {
			if rightsErr == nil {
				rightsErr = perr
			}
			continue
		}
		fds = append(fds, got...)
	}
	fail := func(e error) (*os.File, error) {
		for _, fd := range fds {
			syscall.Close(fd)
		}
		return nil, e
	}

	switch {
	case flags&syscall.MSG_CTRUNC != 0:
		return fail(fmt.Errorf("control message truncated, got %#v", messages))
	case len(messages) != 1:
		return fail(fmt.Errorf("expect 1 message, got %#v", messages))
	case rightsErr != nil:
		return fail(rightsErr)
	case len(fds) != 1:
		return fail(fmt.Errorf("expect 1 fd, got %#v", fds))
	}
	return os.NewFile(uintptr(fds[0]), "remote-file"), nil
}