Diffstat (limited to 'main.go')
-rw-r--r--  main.go  204

1 file changed, 204 insertions, 0 deletions
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..f85e753
--- /dev/null
+++ b/main.go
@@ -0,0 +1,204 @@
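+// main.go implements a small endpoint extractor: given a url or a local file,
+// it pulls link-like strings out of the content with a regex and, for urls,
+// can optionally crawl same-domain links to a configurable depth.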
+package main
+
+import (
+ "compress/gzip"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "os"
+ "regexp"
+ "strings"
+ "time"
+)
+
+var (
+ input string
+ timeout int
+ headers headerFlags
+ userAgent string
+ depth int
+ delay int
+ httpClient = &http.Client{}
+ visited = make(map[string]bool)
+)
+
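+// headerFlags implements flag.Value so -header can be passed multiple times.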
+type headerFlags []string
+
+func (h *headerFlags) String() string {
+ return strings.Join(*h, ", ")
+}
+
+func (h *headerFlags) Set(value string) error {
+ *h = append(*h, value)
+ return nil
+}
+
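+// init wires up the cli flags, validates the required -input flag, and
+// applies the request timeout to the shared http client.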
+func init() {
+ flag.StringVar(&input, "input", "", "url or file path")
+ flag.IntVar(&timeout, "timeout", 10, "timeout for http requests in seconds")
+ flag.StringVar(&userAgent, "user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.10 Safari/605.1.1", "set user-agent")
+ flag.Var(&headers, "header", "add http header to request (e.g. -header \"Authorization: Bearer val1\")")
+ flag.IntVar(&depth, "depth", 0, "recursion depth for same-domain links (0 disables crawling)")
+ flag.IntVar(&delay, "delay", 4, "delay between requests in seconds when crawling (only applies if depth > 0)")
+ flag.Parse()
+
+ if input == "" {
+ fmt.Printf("[err] input is required. use -input <url|file>\n")
+ os.Exit(1)
+ }
+ httpClient.Timeout = time.Duration(timeout) * time.Second
+}
+
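+// defaultRegex matches quoted link-like strings: absolute and protocol-relative
+// urls, relative paths, and paths ending in common web file extensions.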
+var defaultRegex = regexp.MustCompile(`(?:"|')((?:[a-zA-Z]{1,10}://|//)[^"'/]+\.[a-zA-Z]{2,}[^"']*|(?:/|\.\./|\./)[^"'><,;|()*\[\]\s][^"'><,;|()]{1,}|[a-zA-Z0-9_\-/]+/[a-zA-Z0-9_\-/.]+\.(?:php|asp|aspx|jsp|json|action|html|js|txt|xml)(?:[\?|#][^"']*)?)["']`)
+
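+// main dispatches on the input type: urls go through the crawler (depth 0
+// still fetches the start page once), files are parsed directly.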
+func main() {
+ sourceType, err := resolveInput(input)
+ if err != nil {
+ fmt.Printf("[err] %v\n", err)
+ os.Exit(1)
+ }
+
+ if sourceType == "url" {
+ baseURL, _ := url.Parse(input)
+ crawl(input, baseURL, depth)
+ } else {
+ content, err := fetchContent("file", input)
+ if err != nil {
+ fmt.Printf("[err] failed to fetch %s: %v\n", input, err)
+ os.Exit(1)
+ }
+ matches := parseContent(content, defaultRegex)
+ printMatches(input, matches)
+ }
+}
+
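+// resolveInput classifies the input as a url (http/https prefix) or an
+// existing regular file.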
+func resolveInput(input string) (string, error) {
+ if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
+ return "url", nil
+ }
+ if info, err := os.Stat(input); err == nil && !info.IsDir() {
+ return "file", nil
+ }
+ return "", fmt.Errorf("input must be a valid url or file path")
+}
+
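+// fetchContent returns the body of a url or file as a string. Already-visited
+// urls return an empty string with a nil error so the crawler can skip them.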
+func fetchContent(sourceType, target string) (string, error) {
+ if sourceType == "url" {
+ if visited[target] {
+ return "", nil
+ }
+ visited[target] = true
+
+ req, err := http.NewRequest("GET", target, nil)
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("User-Agent", userAgent)
+ req.Header.Set("Accept-Encoding", "gzip")
+ for _, h := range headers {
+ parts := strings.SplitN(h, ":", 2)
+ if len(parts) == 2 {
+ req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
+ }
+ }
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ var reader io.ReadCloser
+ switch resp.Header.Get("Content-Encoding") {
+ case "gzip":
+ reader, err = gzip.NewReader(resp.Body)
+ if err != nil {
+ return "", err
+ }
+ defer reader.Close()
+ default:
+ reader = resp.Body
+ }
+ body, err := io.ReadAll(reader)
+ if err != nil {
+ return "", err
+ }
+ return string(body), nil
+ }
+
+ content, err := os.ReadFile(target)
+ if err != nil {
+ return "", err
+ }
+ return string(content), nil
+}
+
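+// parseContent returns the unique capture-group matches of re in content,
+// preserving first-seen order.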
+func parseContent(content string, re *regexp.Regexp) []string {
+ seen := make(map[string]bool)
+ results := []string{}
+ for _, match := range re.FindAllStringSubmatch(content, -1) {
+ link := match[1]
+ if !seen[link] {
+ seen[link] = true
+ results = append(results, link)
+ }
+ }
+ return results
+}
+
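+// printMatches prints the source followed by each match, one per line; note
+// that all output is normalized to lowercase.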
+func printMatches(source string, matches []string) {
+ fmt.Printf("[inf] %s\n", strings.ToLower(source))
+ for _, match := range matches {
+ fmt.Println(strings.ToLower(match))
+ }
+ fmt.Println()
+}
+
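+// crawl fetches target, prints its matches, and recurses into same-host links
+// until maxDepth is exhausted, sleeping between requests when a delay is set.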
+func crawl(target string, base *url.URL, maxDepth int) {
+ if maxDepth < 0 {
+ return
+ }
+ content, err := fetchContent("url", target)
+ if err != nil {
+ fmt.Printf("[err] failed to fetch %s: %v\n", target, err)
+ return
+ }
+ // empty content with a nil error means the url was already visited
+ if content == "" {
+ return
+ }
+ matches := parseContent(content, defaultRegex)
+ printMatches(target, matches)
+
+ for _, match := range matches {
+ u := resolveURL(base, match)
+ if u == "" {
+ continue
+ }
+ parsed, err := url.Parse(u)
+ if err != nil || parsed.Host != base.Host {
+ continue
+ }
+
+ if maxDepth > 0 && delay > 0 {
+ time.Sleep(time.Duration(delay) * time.Second)
+ }
+
+ crawl(u, base, maxDepth-1)
+ }
+}
+
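+// resolveURL turns an extracted href into an absolute url against base,
+// handling protocol-relative (//host/...) references explicitly.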
+func resolveURL(base *url.URL, href string) string {
+ href = strings.TrimSpace(href)
+ if strings.HasPrefix(href, "//") {
+ return base.Scheme + ":" + href
+ }
+ if strings.HasPrefix(href, "http://") || strings.HasPrefix(href, "https://") {
+ return href
+ }
+ ref, err := url.Parse(href)
+ if err != nil {
+ return ""
+ }
+ return base.ResolveReference(ref).String()
+}