diff --git a/cmd/connect.go b/cmd/connect.go index c797d1a..1eb30e3 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -30,8 +30,7 @@ func init() { } var connectCmd = &cobra.Command{ - Use: "connect", - + Use: "connect", Short: "Connect to a vpn server (survey selection appears if hostname is not provided)", Long: `Connect to a vpn from a list of relay servers`, Args: cobra.RangeArgs(0, 1), @@ -39,49 +38,43 @@ var connectCmd = &cobra.Command{ vpnServers, err := vpn.GetList(flagProxy, flagSocks5Proxy) if err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } - serverSelection := []string{} - serverSelected := vpn.Server{} - - for _, s := range *vpnServers { - serverSelection = append(serverSelection, fmt.Sprintf("%s (%s)", s.HostName, s.CountryLong)) + // Build server selection options and hostname lookup map + serverSelection := make([]string, len(*vpnServers)) + serverMap := make(map[string]vpn.Server, len(*vpnServers)) + for i, s := range *vpnServers { + serverSelection[i] = fmt.Sprintf("%s (%s)", s.HostName, s.CountryLong) + serverMap[s.HostName] = s } selection := "" - prompt := &survey.Select{ - Message: "Choose a server:", - Options: serverSelection, - } + var serverSelected vpn.Server if !flagRandom { - if len(args) > 0 { selection = args[0] } else { + prompt := &survey.Select{ + Message: "Choose a server:", + Options: serverSelection, + } err := survey.AskOne(prompt, &selection, survey.WithPageSize(10)) if err != nil { - log.Error().Msg("Unable to obtain hostname from survey") - os.Exit(1) + log.Fatal().Msg("Unable to obtain hostname from survey") } } - // Server lookup from selection could be more optimized with a hash map - for _, s := range *vpnServers { - if strings.Contains(selection, s.HostName) { - serverSelected = s - } - } - - if serverSelected.HostName == "" { + // Lookup server from selection using map for O(1) lookup + hostname := extractHostname(selection) + if server, exists := serverMap[hostname]; exists { + serverSelected = server + } else { log.Fatal().Msgf("Server '%s' was not found", selection) - os.Exit(1) } } for { - if flagRandom { // Select a random server serverSelected = (*vpnServers)[rand.Intn(len(*vpnServers))] @@ -90,23 +83,19 @@ var connectCmd = &cobra.Command{ decodedConfig, err := base64.StdEncoding.DecodeString(serverSelected.OpenVpnConfigData) if err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } tmpfile, err := os.CreateTemp("", "vpngate-openvpn-config-") if err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } if _, err := tmpfile.Write(decodedConfig); err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } if err := tmpfile.Close(); err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } log.Info().Msgf("Connecting to %s (%s) in %s", serverSelected.HostName, serverSelected.IPAddr, serverSelected.CountryLong) @@ -114,16 +103,22 @@ var connectCmd = &cobra.Command{ err = vpn.Connect(tmpfile.Name()) if err != nil && !flagReconnect { - log.Fatal().Msg(err.Error()) - os.Exit(1) - } else { - err = os.Remove(tmpfile.Name()) - if err != nil { - log.Fatal().Msg(err.Error()) - os.Exit(1) - } + // VPN connection failed and reconnect is disabled + _ = os.Remove(tmpfile.Name()) + log.Fatal().Msg("VPN connection failed") } + // Always try to clean up temporary file + _ = os.Remove(tmpfile.Name()) } }, } + +// extractHostname extracts the hostname from the selection string (format: "hostname (country)") +func extractHostname(selection string) string { + parts := strings.Split(selection, " (") + if len(parts) > 0 { + return parts[0] + } + return selection +} \ No newline at end of file diff --git a/cmd/list.go b/cmd/list.go index 1fd0e8b..0f1ba75 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -27,7 +27,6 @@ var listCmd = &cobra.Command{ vpnServers, err := vpn.GetList(flagProxy, flagSocks5Proxy) if err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } table := tw.NewWriter(os.Stdout) @@ -37,13 +36,11 @@ var listCmd = &cobra.Command{ err := table.Append([]string{strconv.Itoa(i + 1), v.HostName, v.CountryLong, v.Ping, strconv.Itoa(v.Score)}) if err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } } err = table.Render() if err != nil { log.Fatal().Msg(err.Error()) - os.Exit(1) } }, -} +} \ No newline at end of file diff --git a/go.mod b/go.mod index ab55bbb..178d84c 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( github.com/juju/errors v1.0.0 github.com/olekukonko/tablewriter v1.1.3 github.com/rs/zerolog v1.34.0 - github.com/spf13/afero v1.15.0 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 golang.org/x/net v0.50.0 diff --git a/go.sum b/go.sum index 2a7f742..5a9eab7 100644 --- a/go.sum +++ b/go.sum @@ -66,8 +66,6 @@ github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= -github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= diff --git a/pkg/exec/run.go b/pkg/exec/run.go index 3851669..34af0b1 100644 --- a/pkg/exec/run.go +++ b/pkg/exec/run.go @@ -2,46 +2,58 @@ package exec import ( "bufio" - "os" + "io" "os/exec" - "strings" "github.com/rs/zerolog/log" ) -// Run executes a command in workDir and returns stdout and error. -// The spawned process will exit upon termination of this application -// to ensure a clean exit +// Run executes a command in workDir and logs its output. +// If the command fails to start or setup fails, an error is logged and returned. +// If the command exits with a non-zero status, the error is returned without logging +// (this allows the caller to decide how to handle it). func Run(path string, workDir string, args ...string) error { _, err := exec.LookPath(path) if err != nil { log.Error().Msgf("%s is required, please install it", path) - os.Exit(1) + return err } + cmd := exec.Command(path, args...) cmd.Dir = workDir - log.Debug().Msg("Executing " + strings.Join(cmd.Args, " ")) + + log.Debug().Strs("command", cmd.Args).Msg("Executing command") + stdout, err := cmd.StdoutPipe() if err != nil { - log.Fatal().Msgf("Failed to get stdout pipe: %v", err) + log.Error().Msgf("Failed to get stdout pipe: %v", err) + return err } + + stderr, err := cmd.StderrPipe() + if err != nil { + log.Error().Msgf("Failed to get stderr pipe: %v", err) + return err + } + if err := cmd.Start(); err != nil { - log.Fatal().Msgf("Failed to start command: %v", err) + log.Error().Msgf("Failed to start command: %v", err) + return err } - scanner := bufio.NewScanner(stdout) - + // Combine stdout and stderr into a single reader + combined := io.MultiReader(stdout, stderr) + scanner := bufio.NewScanner(combined) for scanner.Scan() { log.Debug().Msg(scanner.Text()) } if err := scanner.Err(); err != nil { - log.Fatal().Msgf("Error reading stdout: %v", err) - } - - if err := cmd.Wait(); err != nil { - log.Fatal().Msgf("Command finished with error: %v", err) + log.Error().Msgf("Error reading output: %v", err) return err } - return nil -} + + // cmd.Wait() returns an error if the command exits with non-zero status + // We return this without logging since it's expected behavior for some commands + return cmd.Wait() +} \ No newline at end of file diff --git a/pkg/util/retry.go b/pkg/util/retry.go index 0bb9a77..8ed3385 100644 --- a/pkg/util/retry.go +++ b/pkg/util/retry.go @@ -2,17 +2,18 @@ package util import ( "time" + "github.com/rs/zerolog/log" ) -func Retry(attempts int, delay time.Duration,fn func() error) error { +func Retry(attempts int, delay time.Duration, fn func() error) error { var err error for i := 0; i < attempts; i++ { if err = fn(); err == nil { return nil } - log.Error().Msgf("Retrying after %d seconds. An error occured: %s", delay, err) + log.Error().Msgf("Retrying after %v. An error occurred: %s", delay, err) time.Sleep(delay) } return err -} +} \ No newline at end of file diff --git a/pkg/vpn/cache.go b/pkg/vpn/cache.go index 10d6cea..2de0a53 100644 --- a/pkg/vpn/cache.go +++ b/pkg/vpn/cache.go @@ -4,37 +4,42 @@ import ( "encoding/json" "io" "os" - "path" + "path/filepath" "time" - - "github.com/rs/zerolog/log" - "github.com/spf13/afero" ) const serverCachefile = "servers.json" -func getCacheDir() string { +func getCacheDir() (string, error) { homeDir, err := os.UserHomeDir() if err != nil { - log.Error().Msgf("Failed to get user's home directory: %s ", err) - return "" + return "", err } - cacheDir := path.Join(homeDir, ".vpngate", "cache") - return cacheDir + cacheDir := filepath.Join(homeDir, ".vpngate", "cache") + return cacheDir, nil } func createCacheDir() error { - cacheDir := getCacheDir() - AppFs := afero.NewOsFs() - return AppFs.MkdirAll(cacheDir, 0o700) + cacheDir, err := getCacheDir() + if err != nil { + return err + } + return os.MkdirAll(cacheDir, 0o700) } func getVpnListCache() (*[]Server, error) { - cacheFile := path.Join(getCacheDir(), serverCachefile) + cacheDir, err := getCacheDir() + if err != nil { + return nil, err + } + cacheFile := filepath.Join(cacheDir, serverCachefile) serversFile, err := os.Open(cacheFile) if err != nil { return nil, err } + defer func() { + _ = serversFile.Close() + }() byteValue, err := io.ReadAll(serversFile) if err != nil { @@ -44,7 +49,6 @@ func getVpnListCache() (*[]Server, error) { var servers []Server err = json.Unmarshal(byteValue, &servers) - if err != nil { return nil, err } @@ -53,8 +57,7 @@ func getVpnListCache() (*[]Server, error) { } func writeVpnListToCache(servers []Server) error { - err := createCacheDir() - if err != nil { + if err := createCacheDir(); err != nil { return err } @@ -63,20 +66,26 @@ func writeVpnListToCache(servers []Server) error { return err } - cacheFile := path.Join(getCacheDir(), serverCachefile) + cacheDir, err := getCacheDir() + if err != nil { + return err + } + cacheFile := filepath.Join(cacheDir, serverCachefile) - err = os.WriteFile(cacheFile, f, 0o644) - - return err + return os.WriteFile(cacheFile, f, 0o644) } func vpnListCacheIsExpired() bool { - file, err := os.Stat(path.Join(getCacheDir(), serverCachefile)) + cacheDir, err := getCacheDir() + if err != nil { + return true + } + file, err := os.Stat(filepath.Join(cacheDir, serverCachefile)) if err != nil { return true } lastModified := file.ModTime() - return (time.Since(lastModified)) > time.Duration(24*time.Hour) -} + return time.Since(lastModified) > 24*time.Hour +} \ No newline at end of file diff --git a/pkg/vpn/client.go b/pkg/vpn/client.go index 3437620..bc25e17 100644 --- a/pkg/vpn/client.go +++ b/pkg/vpn/client.go @@ -1,28 +1,17 @@ package vpn import ( - "os" "runtime" "github.com/davegallant/vpngate/pkg/exec" - "github.com/juju/errors" ) // Connect to a specified OpenVPN configuration func Connect(configPath string) error { - tmpLogFile, err := os.CreateTemp("", "vpngate-openvpn-log-") - if err != nil { - return errors.Annotate(err, "Unable to create a temporary log file") - } - defer func() { - _ = os.Remove(tmpLogFile.Name()) - }() - executable := "openvpn" if runtime.GOOS == "windows" { executable = "C:\\Program Files\\OpenVPN\\bin\\openvpn.exe" } - err = exec.Run(executable, ".", "--verb", "4", "--config", configPath, "--data-ciphers", "AES-128-CBC") - return err -} + return exec.Run(executable, ".", "--verb", "4", "--config", configPath, "--data-ciphers", "AES-128-CBC") +} \ No newline at end of file diff --git a/pkg/vpn/list.go b/pkg/vpn/list.go index a35a41e..aad493a 100644 --- a/pkg/vpn/list.go +++ b/pkg/vpn/list.go @@ -2,10 +2,12 @@ package vpn import ( "bytes" + "context" "io" + "net" "net/http" "net/url" - "os" + "time" "github.com/jszwec/csvutil" "github.com/rs/zerolog/log" @@ -16,10 +18,12 @@ import ( ) const ( - vpnList = "https://www.vpngate.net/api/iphone/" + vpnList = "https://www.vpngate.net/api/iphone/" + httpClientTimeout = 30 * time.Second + dialTimeout = 10 * time.Second ) -// Server holds in formation about a vpn relay server +// Server holds information about a vpn relay server type Server struct { HostName string `csv:"#HostName"` CountryLong string `csv:"CountryLong"` @@ -30,20 +34,14 @@ type Server struct { Ping string `csv:"Ping"` } -func streamToBytes(stream io.Reader) []byte { - buf := new(bytes.Buffer) - _, err := buf.ReadFrom(stream) - if err != nil { - log.Error().Msg("Unable to stream bytes") - } - return buf.Bytes() -} - -// parse csv +// parseVpnList parses the VPN server list from CSV format func parseVpnList(r io.Reader) (*[]Server, error) { var servers []Server - serverList := streamToBytes(r) + serverList, err := io.ReadAll(r) + if err != nil { + return nil, errors.Annotate(err, "Unable to read stream") + } // Trim known invalid rows serverList = bytes.TrimPrefix(serverList, []byte("*vpn_servers\r\n")) @@ -57,86 +55,119 @@ func parseVpnList(r io.Reader) (*[]Server, error) { return &servers, nil } +// createHTTPClient creates an HTTP client with optional proxy configuration +func createHTTPClient(httpProxy string, socks5Proxy string) (*http.Client, error) { + if httpProxy != "" { + proxyURL, err := url.Parse(httpProxy) + if err != nil { + return nil, errors.Annotatef(err, "Error parsing HTTP proxy: %s", httpProxy) + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + return &http.Client{ + Transport: transport, + Timeout: httpClientTimeout, + }, nil + } + + if socks5Proxy != "" { + dialer, err := proxy.SOCKS5("tcp", socks5Proxy, nil, proxy.Direct) + if err != nil { + return nil, errors.Annotatef(err, "Error creating SOCKS5 dialer: %v", err) + } + + // Create a DialContext function from the SOCKS5 dialer + dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Check if context is already done + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Use the dialer with a timeout + conn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + // Respect context cancellation after connection + go func() { + <-ctx.Done() + _ = conn.Close() + }() + + return conn, nil + } + + httpTransport := &http.Transport{ + DialContext: dialContext, + } + return &http.Client{ + Transport: httpTransport, + Timeout: httpClientTimeout, + }, nil + } + + return &http.Client{ + Timeout: httpClientTimeout, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: dialTimeout, + }).DialContext, + }, + }, nil +} + // GetList returns a list of vpn servers func GetList(httpProxy string, socks5Proxy string) (*[]Server, error) { cacheExpired := vpnListCacheIsExpired() - var servers *[]Server - var client *http.Client - + // Try to use cached list if not expired if !cacheExpired { servers, err := getVpnListCache() - - if err != nil { - log.Info().Msg("Unable to retrieve vpn list from cache") - } else { + if err == nil { return servers, nil } - + log.Info().Msg("Unable to retrieve vpn list from cache") } else { log.Info().Msg("The vpn server list cache has expired") } log.Info().Msg("Fetching the latest server list") - if httpProxy != "" { - proxyURL, err := url.Parse(httpProxy) - if err != nil { - log.Error().Msgf("Error parsing proxy: %s", err) - os.Exit(1) - } - transport := &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - - client = &http.Client{ - Transport: transport, - } - - } else if socks5Proxy != "" { - dialer, err := proxy.SOCKS5("tcp", socks5Proxy, nil, proxy.Direct) - if err != nil { - log.Error().Msgf("Error creating SOCKS5 dialer: %v", err) - os.Exit(1) - } - - httpTransport := &http.Transport{ - Dial: dialer.Dial, - } - - client = &http.Client{ - Transport: httpTransport, - } - } else { - client = &http.Client{} + client, err := createHTTPClient(httpProxy, socks5Proxy) + if err != nil { + return nil, err } - var r *http.Response + var servers *[]Server - err := util.Retry(5, 1, func() error { - var err error - r, err = client.Get(vpnList) + err = util.Retry(5, 1, func() error { + resp, err := client.Get(vpnList) if err != nil { return err } defer func() { - _ = r.Body.Close() + _ = resp.Body.Close() }() - if r.StatusCode != 200 { - return errors.Annotatef(err, "Unexpected status code when retrieving vpn list: %d", r.StatusCode) + if resp.StatusCode != http.StatusOK { + return errors.Annotatef(err, "Unexpected status code when retrieving vpn list: %d", resp.StatusCode) } - servers, err = parseVpnList(r.Body) - + parsedServers, err := parseVpnList(resp.Body) if err != nil { return err } - err = writeVpnListToCache(*servers) + servers = parsedServers - if err != nil { - log.Warn().Msgf("Unable to write servers to cache: %s", err) + // Cache the servers for future use + cacheErr := writeVpnListToCache(*servers) + if cacheErr != nil { + log.Warn().Msgf("Unable to write servers to cache: %s", cacheErr) } return nil }) @@ -146,4 +177,4 @@ func GetList(httpProxy string, socks5Proxy string) (*[]Server, error) { } return servers, nil -} +} \ No newline at end of file