Files
NexaVPN/deploy/access-proxy/main.go
nessi a8a88140af refactor: replace Peek with ReadFull in TLS ClientHello parsing to prevent buffering issues
Replace bufio.Reader.Peek calls with io.ReadFull for header and record body reading. Allocate header and full buffers explicitly and copy header into full buffer before reading remaining bytes. Remove redundant byte slice copy when returning full ClientHello data.
2026-03-19 22:38:12 +01:00

330 lines
8.0 KiB
Go

package main
import (
"bufio"
"context"
"encoding/json"
"errors"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
type (
	// syncBundle is the JSON document decoded from the gateway-agent sync
	// endpoint. Only the peer list is consumed by this proxy.
	syncBundle struct {
		Peers []peerConfig `json:"peers"`
	}

	// peerConfig describes one VPN peer. AssignedIP may carry a CIDR suffix,
	// which refreshConfig strips before using it as a lookup key.
	peerConfig struct {
		AssignedIP      string           `json:"assigned_ip"`
		AllowedServices []allowedService `json:"allowed_services"`
	}

	// allowedService is one service a peer may reach.
	//
	// Domain (normalized) is the lookup key. ProxyIP is the upstream address
	// dialed on port 80 (HTTP) or 443 (TLS passthrough). Name and
	// AccessProxyIP are decoded but not read anywhere in this file.
	allowedService struct {
		Name          string `json:"name"`
		Domain        string `json:"domain"`
		ProxyIP       string `json:"proxy_ip"`
		AccessProxyIP string `json:"access_proxy_ip"`
	}

	// proxyState holds the current allow-list, keyed by client IP and then by
	// normalized host name. mu guards allowed; refreshConfig swaps in a whole
	// new map under the write lock, and lookup reads under the read lock.
	proxyState struct {
		mu      sync.RWMutex
		allowed map[string]map[string]allowedService
	}
)
// main loads the allow-list and refreshes it every 15 seconds in the
// background. It serves the plain-HTTP reverse proxy on
// NEXAVPN_ACCESS_PROXY_HTTP_ADDR (default ":8088") and runs the SNI-based TLS
// passthrough proxy on NEXAVPN_ACCESS_PROXY_HTTPS_ADDR (default ":8448").
// A failure of either listener terminates the process.
func main() {
	state := &proxyState{
		allowed: make(map[string]map[string]allowedService),
	}
	ctx := context.Background()
	if err := refreshConfig(ctx, state); err != nil {
		log.Printf("initial config refresh failed: %v", err)
	}
	go func() {
		ticker := time.NewTicker(15 * time.Second)
		defer ticker.Stop()
		for range ticker.C {
			if err := refreshConfig(ctx, state); err != nil {
				log.Printf("config refresh failed: %v", err)
			}
		}
	}()
	httpAddr := envOrDefault("NEXAVPN_ACCESS_PROXY_HTTP_ADDR", ":8088")
	httpsAddr := envOrDefault("NEXAVPN_ACCESS_PROXY_HTTPS_ADDR", ":8448")
	httpServer := &http.Server{
		Addr: httpAddr,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			handleHTTP(state, w, r)
		}),
		// http.ListenAndServe sets no timeouts. Without a header deadline,
		// a client can hold a connection open indefinitely by trickling
		// request headers (slowloris). Bodies and proxied responses are not
		// affected by this limit.
		ReadHeaderTimeout: 10 * time.Second,
	}
	go func() {
		log.Printf("HTTP access proxy listening on %s", httpAddr)
		if err := httpServer.ListenAndServe(); err != nil {
			log.Fatalf("http server failed: %v", err)
		}
	}()
	log.Printf("HTTPS access proxy listening on %s", httpsAddr)
	if err := serveTLSProxy(httpsAddr, state); err != nil {
		log.Fatalf("https proxy failed: %v", err)
	}
}
// handleHTTP reverse-proxies one plain-HTTP request. The request is forwarded
// only when the client's source IP is allowed to reach the requested host; it
// goes to that service's ProxyIP on port 80. Otherwise the client gets
// 403 Forbidden.
//
// The normalized host is sent upstream as the Host header and as
// X-Forwarded-Host.
func handleHTTP(state *proxyState, w http.ResponseWriter, r *http.Request) {
	clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		http.Error(w, "invalid client address", http.StatusForbidden)
		return
	}
	host := normalizeHost(r.Host)
	service, ok := state.lookup(clientIP, host)
	if !ok {
		http.Error(w, "service not allowed", http.StatusForbidden)
		return
	}
	targetURL := &url.URL{
		Scheme: "http",
		Host:   net.JoinHostPort(service.ProxyIP, "80"),
	}
	proxy := httputil.NewSingleHostReverseProxy(targetURL)
	originalDirector := proxy.Director
	proxy.Director = func(req *http.Request) {
		originalDirector(req)
		// req.Host is what the transport sends as the Host header. A "Host"
		// entry in req.Header is ignored for outgoing requests, so only
		// req.Host is set here.
		req.Host = host
		req.Header.Set("X-Forwarded-Host", host)
	}
	proxy.ErrorHandler = func(rw http.ResponseWriter, _ *http.Request, proxyErr error) {
		// The raw error names internal upstream addresses. Keep it in the
		// server log and give the client only a generic status.
		log.Printf("http proxy for %s via %s failed: %v", host, targetURL.Host, proxyErr)
		http.Error(rw, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
	}
	proxy.ServeHTTP(w, r)
}
// serveTLSProxy listens on addr (TCP) and hands each accepted connection to
// handleTLSConn on its own goroutine.
//
// Accept failures other than a closed listener (for example running out of
// file descriptors) are logged and retried with a capped exponential backoff,
// mirroring net/http's Server, instead of ending the loop. It returns only
// when the listener cannot be created or has been closed.
func serveTLSProxy(addr string, state *proxyState) error {
	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	defer listener.Close()
	var backoff time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return err
			}
			// A transient error must not take the whole proxy down (main
			// calls log.Fatalf on any returned error).
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			log.Printf("https accept failed: %v; retrying in %v", err, backoff)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go handleTLSConn(conn, state)
	}
}
// handleTLSConn serves one connection accepted on the HTTPS port without
// terminating TLS. It:
//   - reads the client's first TLS record and extracts its SNI host;
//   - checks that the client's source IP may reach that host;
//   - dials the service's ProxyIP on port 443, replays the ClientHello bytes,
//     then copies data in both directions.
//
// Any failure drops the connection silently. Both connections are closed as
// soon as either copy direction finishes.
func handleTLSConn(clientConn net.Conn, state *proxyState) {
	// Maximum time a client may take to deliver its ClientHello.
	const clientHelloTimeout = 10 * time.Second

	defer clientConn.Close()
	clientIP, _, err := net.SplitHostPort(clientConn.RemoteAddr().String())
	if err != nil {
		return
	}
	// Without a deadline, a peer that connects and never sends (or trickles)
	// its ClientHello pins this goroutine and its file descriptor forever.
	if err := clientConn.SetReadDeadline(time.Now().Add(clientHelloTimeout)); err != nil {
		return
	}
	// reader may buffer bytes beyond the ClientHello, so it (not clientConn)
	// must be the source of the client->upstream copy below.
	reader := bufio.NewReader(clientConn)
	hello, host, err := readClientHello(reader)
	if err != nil {
		return
	}
	// Clear the deadline: an established session may legitimately idle.
	if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
		return
	}
	service, ok := state.lookup(clientIP, host)
	if !ok {
		return
	}
	upstreamConn, err := net.DialTimeout("tcp", net.JoinHostPort(service.ProxyIP, "443"), 10*time.Second)
	if err != nil {
		return
	}
	defer upstreamConn.Close()
	if _, err := upstreamConn.Write(hello); err != nil {
		return
	}
	// Capacity 2 lets the second copier report without blocking after this
	// function has returned and closed both connections.
	errCh := make(chan error, 2)
	go proxyCopy(errCh, upstreamConn, reader)
	go proxyCopy(errCh, clientConn, upstreamConn)
	<-errCh
}
// proxyCopy streams src into dst until src is exhausted or an error occurs.
// It then sends exactly one value on errCh: the copy error, or nil when src
// ended with a clean EOF.
func proxyCopy(errCh chan<- error, dst io.Writer, src io.Reader) {
	if _, copyErr := io.Copy(dst, src); copyErr != nil {
		errCh <- copyErr
		return
	}
	errCh <- nil
}
// readClientHello consumes exactly one TLS record from reader. The record
// must be a handshake record (content type 22) whose length is within the
// TLS plaintext limit. It returns the raw record bytes (5-byte header plus
// body) for replay upstream, together with the normalized SNI host found by
// extractSNI.
//
// Any read error, a non-handshake record, an invalid record length, or a
// missing or malformed SNI is returned as an error.
func readClientHello(reader *bufio.Reader) ([]byte, string, error) {
	// RFC 8446 §5.1 (and RFC 5246 §6.2.1): a plaintext record's length MUST
	// NOT exceed 2^14. Bounding it also caps the per-connection allocation
	// an unauthenticated client can force at ~16 KiB instead of 64 KiB.
	const maxRecordLen = 1 << 14

	header := make([]byte, 5)
	if _, err := io.ReadFull(reader, header); err != nil {
		return nil, "", err
	}
	if header[0] != 22 {
		return nil, "", errors.New("not a tls client hello")
	}
	recordLen := int(header[3])<<8 | int(header[4])
	if recordLen == 0 || recordLen > maxRecordLen {
		return nil, "", errors.New("invalid tls record length")
	}
	full := make([]byte, 5+recordLen)
	copy(full, header)
	if _, err := io.ReadFull(reader, full[5:]); err != nil {
		return nil, "", err
	}
	host, err := extractSNI(full)
	if err != nil {
		return nil, "", err
	}
	return full, host, nil
}
// extractSNI parses a single TLS record that carries a ClientHello and
// returns the normalized host name from its server_name extension
// (RFC 6066 §3).
//
// packet must start at the record header. The walk is:
//
//	[0]   record header: content_type(1) version(2) length(2)
//	[5]   handshake header: msg_type(1) length(3)
//	[9]   client_version(2)
//	[11]  random(32)
//	[43]  legacy_session_id: length(1) + bytes
//	      cipher_suites: length(2) + bytes
//	      compression_methods: length(1) + bytes
//	      extensions: length(2) + (type(2) length(2) data)*
//
// An error is returned when the handshake is not a ClientHello, any length
// field runs past the data, or no server_name extension is present.
func extractSNI(packet []byte) (string, error) {
	// Offset of the legacy_session_id length byte:
	// record header (5) + handshake header (4) + version (2) + random (32).
	const sessionIDLenOffset = 5 + 4 + 2 + 32
	// Reading packet[sessionIDLenOffset] needs at least 44 bytes. The old
	// "< 43" guard let a 43-byte packet through and panicked on the index.
	if len(packet) <= sessionIDLenOffset {
		return "", errors.New("tls packet too short")
	}
	// Handshake msg_type 1 is client_hello.
	if packet[5] != 1 {
		return "", errors.New("not a tls client hello")
	}
	sessionIDLen := int(packet[sessionIDLenOffset])
	offset := sessionIDLenOffset + 1 + sessionIDLen
	if len(packet) < offset+2 {
		return "", errors.New("missing cipher suites")
	}
	cipherLen := int(packet[offset])<<8 | int(packet[offset+1])
	offset += 2 + cipherLen
	if len(packet) < offset+1 {
		return "", errors.New("missing compression methods")
	}
	compressionLen := int(packet[offset])
	offset += 1 + compressionLen
	if len(packet) < offset+2 {
		return "", errors.New("missing extensions")
	}
	extensionsLen := int(packet[offset])<<8 | int(packet[offset+1])
	offset += 2
	end := offset + extensionsLen
	if len(packet) < end {
		return "", errors.New("invalid extensions length")
	}
	for offset+4 <= end {
		extensionType := int(packet[offset])<<8 | int(packet[offset+1])
		extensionLen := int(packet[offset+2])<<8 | int(packet[offset+3])
		offset += 4
		if offset+extensionLen > end {
			return "", errors.New("invalid extension")
		}
		if extensionType == 0 {
			// server_name data: server_name_list length(2), then the first
			// entry's name_type(1) and HostName length(2), then the name.
			if extensionLen < 5 {
				return "", errors.New("invalid sni extension")
			}
			serverNameLen := int(packet[offset+3])<<8 | int(packet[offset+4])
			// Bound the name by this extension rather than the whole
			// extensions block, so a bogus length cannot spill into the
			// following extension.
			if 5+serverNameLen > extensionLen {
				return "", errors.New("invalid sni length")
			}
			return normalizeHost(string(packet[offset+5 : offset+5+serverNameLen])), nil
		}
		offset += extensionLen
	}
	return "", errors.New("sni not found")
}
// refreshConfig fetches the current sync bundle for this gateway and
// atomically replaces state's allow-list with it.
//
// The bundle is fetched from
// <NEXAVPN_GATEWAY_SYNC_URL>/<gatewayID>/sync, authenticated with the
// GATEWAY_BOOTSTRAP_TOKEN header. The new map is keyed by each peer's
// AssignedIP (CIDR suffix stripped), then by each service's normalized
// Domain.
//
// On any error (gateway ID resolution, transport, non-2xx status, JSON
// decode, or the per-sync timeout) the existing allow-list is left
// untouched and the error is returned.
func refreshConfig(ctx context.Context, state *proxyState) error {
	// http.DefaultClient has no timeout and callers pass context.Background().
	// Without a bound, a stalled control plane would wedge the periodic
	// refresh loop forever. The bound covers the body read as well.
	const syncTimeout = 10 * time.Second

	gatewayID, err := resolveGatewayID()
	if err != nil {
		return err
	}
	syncURL := strings.TrimRight(envOrDefault("NEXAVPN_GATEWAY_SYNC_URL", "http://127.0.0.1:8080/api/v1/gateway-agent"), "/") + "/" + gatewayID + "/sync"
	reqCtx, cancel := context.WithTimeout(ctx, syncTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, syncURL, nil)
	if err != nil {
		return err
	}
	req.Header.Set("X-Gateway-Bootstrap-Token", os.Getenv("GATEWAY_BOOTSTRAP_TOKEN"))
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return errors.New("sync request failed with status " + resp.Status)
	}
	var bundle syncBundle
	if err := json.NewDecoder(resp.Body).Decode(&bundle); err != nil {
		return err
	}
	allowed := make(map[string]map[string]allowedService, len(bundle.Peers))
	for _, peer := range bundle.Peers {
		hostMap := make(map[string]allowedService, len(peer.AllowedServices))
		for _, service := range peer.AllowedServices {
			hostMap[normalizeHost(service.Domain)] = service
		}
		allowed[stripCIDR(peer.AssignedIP)] = hostMap
	}
	state.mu.Lock()
	state.allowed = allowed
	state.mu.Unlock()
	return nil
}
// lookup reports whether clientIP may reach host and, if so, returns the
// matching service. host is normalized before matching, and clientIP must
// match a key exactly. The read lock is held for the duration of the lookup.
func (s *proxyState) lookup(clientIP string, host string) (allowedService, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	// Indexing a missing (nil) inner map yields the zero value and false,
	// so an unknown client needs no separate branch.
	svc, found := s.allowed[clientIP][normalizeHost(host)]
	return svc, found
}
// resolveGatewayID returns this gateway's identifier, which refreshConfig
// uses as a path segment of the sync URL.
//
// A non-blank NEXAVPN_GATEWAY_ID takes precedence. Otherwise the ID is read
// from the file named by NEXAVPN_GATEWAY_ID_FILE (default
// /var/lib/nexavpn/gateway-id), with surrounding whitespace trimmed.
// An unreadable file, or one containing only whitespace, is an error.
func resolveGatewayID() (string, error) {
	if value := strings.TrimSpace(os.Getenv("NEXAVPN_GATEWAY_ID")); value != "" {
		return value, nil
	}
	stateFile := filepath.Clean(envOrDefault("NEXAVPN_GATEWAY_ID_FILE", "/var/lib/nexavpn/gateway-id"))
	raw, err := os.ReadFile(stateFile)
	if err != nil {
		return "", err
	}
	gatewayID := strings.TrimSpace(string(raw))
	if gatewayID == "" {
		// An empty ID would silently produce a ".../gateway-agent//sync" URL.
		return "", errors.New("gateway id file " + stateFile + " is empty")
	}
	return gatewayID, nil
}
// envOrDefault returns the environment variable key with surrounding
// whitespace trimmed. It returns fallback when the variable is unset or
// blank.
func envOrDefault(key string, fallback string) string {
	trimmed := strings.TrimSpace(os.Getenv(key))
	if trimmed == "" {
		return fallback
	}
	return trimmed
}
// stripCIDR drops a "/prefix-length" suffix, returning everything before the
// first '/'. Input without a '/' is returned unchanged.
func stripCIDR(value string) string {
	parts := strings.SplitN(value, "/", 2)
	return parts[0]
}
// normalizeHost canonicalizes a host name for allow-list matching. It:
//   - trims surrounding whitespace and lower-cases the name;
//   - removes an optional ":port" suffix, or the brackets around a port-less
//     IPv6 literal;
//   - finally removes one trailing root dot.
//
// Stripping the dot after the port means "Example.com.:443" and "example.com"
// normalize to the same key. Bare IPv6 literals such as "fd00::1" are
// returned as-is.
func normalizeHost(host string) string {
	host = strings.TrimSpace(strings.ToLower(host))
	if strings.Contains(host, ":") {
		if parsedHost, _, err := net.SplitHostPort(host); err == nil {
			host = parsedHost
		} else if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
			// "[v6]" without a port: match the "[v6]:port" form above.
			host = host[1 : len(host)-1]
		}
	}
	return strings.TrimSuffix(host, ".")
}