aboutsummaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
authorBryan McNulty <bryanmcnulty@protonmail.com>2025-04-20 18:23:36 -0500
committerBryan McNulty <bryanmcnulty@protonmail.com>2025-04-20 18:23:36 -0500
commit1168c8657117cb72426e9e2bfc68bf8ae9575bb1 (patch)
treeb6735b553e80719ccf453bde8db694e192bac8ee /pkg
parent6ade3ddd945e50d7a145294ac4681489be5d22f8 (diff)
downloadgoexec-1168c8657117cb72426e9e2bfc68bf8ae9575bb1.tar.gz
goexec-1168c8657117cb72426e9e2bfc68bf8ae9575bb1.zip
Improve smb.OutputFileFetcher; introduce stage input
Diffstat (limited to 'pkg')
-rw-r--r--pkg/goexec/io.go9
-rw-r--r--pkg/goexec/method.go12
-rw-r--r--pkg/goexec/smb/client.go159
-rw-r--r--pkg/goexec/smb/input.go66
-rw-r--r--pkg/goexec/smb/output.go34
5 files changed, 190 insertions, 90 deletions
diff --git a/pkg/goexec/io.go b/pkg/goexec/io.go
index ab2b704..1d4358f 100644
--- a/pkg/goexec/io.go
+++ b/pkg/goexec/io.go
@@ -27,7 +27,7 @@ type ExecutionOutput struct {
}
type ExecutionInput struct {
- FilePath string
+ StageFile io.ReadCloser
Executable string
ExecutablePath string
Arguments string
@@ -89,3 +89,10 @@ func (i *ExecutionInput) CommandLine() (cmd []string) {
func (i *ExecutionInput) String() string {
return strings.Join(i.CommandLine(), " ")
}
+
+func (i *ExecutionInput) Reader() (reader io.Reader) {
+ if i.StageFile != nil {
+ return i.StageFile
+ }
+ return strings.NewReader(i.String())
+}
diff --git a/pkg/goexec/method.go b/pkg/goexec/method.go
index e57442f..88898f2 100644
--- a/pkg/goexec/method.go
+++ b/pkg/goexec/method.go
@@ -102,17 +102,15 @@ func ExecuteCleanAuxiliaryMethod(ctx context.Context, module CleanAuxiliaryMetho
func ExecuteCleanMethod(ctx context.Context, module CleanExecutionMethod, execIO *ExecutionIO) (err error) {
log := zerolog.Ctx(ctx)
- defer func() {
- if err = module.Clean(ctx); err != nil {
- log.Error().Err(err).Msg("Module cleanup failed")
- err = nil
- }
- }()
-
if err = ExecuteMethod(ctx, module, execIO); err != nil {
return
}
+ if err = module.Clean(ctx); err != nil {
+ log.Error().Err(err).Msg("Module cleanup failed")
+ err = nil
+ }
+
if execIO.Output != nil && execIO.Output.Provider != nil {
log.Info().Msg("Collecting output")
diff --git a/pkg/goexec/smb/client.go b/pkg/goexec/smb/client.go
index 3b41e39..d95481c 100644
--- a/pkg/goexec/smb/client.go
+++ b/pkg/goexec/smb/client.go
@@ -1,112 +1,123 @@
package smb
import (
- "context"
- "errors"
- "fmt"
- "github.com/oiweiwei/go-smb2.fork"
- "github.com/rs/zerolog"
- "net"
+ "context"
+ "errors"
+ "fmt"
+ "github.com/oiweiwei/go-smb2.fork"
+ "github.com/rs/zerolog"
+ "net"
)
type Client struct {
- ClientOptions
+ ClientOptions
- conn net.Conn
- sess *smb2.Session
- mount *smb2.Share
+ conn net.Conn
+ sess *smb2.Session
+ mount *smb2.Share
+
+ connected bool
+ share string
}
func (c *Client) Session() (sess *smb2.Session) {
- return c.sess
+ return c.sess
}
func (c *Client) String() string {
- return ClientName
+ return ClientName
}
func (c *Client) Logger(ctx context.Context) zerolog.Logger {
- return zerolog.Ctx(ctx).With().Str("client", c.String()).Logger()
+ return zerolog.Ctx(ctx).With().Str("client", c.String()).Logger()
}
func (c *Client) Mount(ctx context.Context, share string) (err error) {
- if c.sess == nil {
- return errors.New("SMB session not initialized")
- }
+ if c.sess == nil {
+ return errors.New("SMB session not initialized")
+ }
- c.mount, err = c.sess.Mount(share)
- zerolog.Ctx(ctx).Debug().Str("share", share).Msg("Mounted SMB share")
+ c.mount, err = c.sess.Mount(share)
+ zerolog.Ctx(ctx).Debug().Str("share", share).Msg("Mounted SMB share")
+ c.share = share
- return
+ return
}
func (c *Client) Connect(ctx context.Context) (err error) {
- log := c.Logger(ctx)
- {
- if c.netDialer == nil {
- panic(fmt.Errorf("TCP dialer not initialized"))
- }
- if c.dialer == nil {
- panic(fmt.Errorf("%s dialer not initialized", c.String()))
- }
- }
+ log := c.Logger(ctx)
+ {
+ if c.netDialer == nil {
+ panic(fmt.Errorf("TCP dialer not initialized"))
+ }
+ if c.dialer == nil {
+ panic(fmt.Errorf("%s dialer not initialized", c.String()))
+ }
+ }
- // Establish TCP connection
- c.conn, err = c.netDialer.Dial("tcp", net.JoinHostPort(c.Host, fmt.Sprintf("%d", c.Port)))
+ // Establish TCP connection
+ c.conn, err = c.netDialer.Dial("tcp", net.JoinHostPort(c.Host, fmt.Sprintf("%d", c.Port)))
- if err != nil {
- return err
- }
+ if err != nil {
+ return err
+ }
- log = log.With().Str("address", c.conn.RemoteAddr().String()).Logger()
- log.Debug().Msgf("Connected to %s server", c.String())
+ log = log.With().Str("address", c.conn.RemoteAddr().String()).Logger()
+ log.Debug().Msgf("Connected to %s server", c.String())
- // Open SMB session
- c.sess, err = c.dialer.DialContext(ctx, c.conn)
+ // Open SMB session
+ c.sess, err = c.dialer.DialContext(ctx, c.conn)
- if err != nil {
- log.Error().Err(err).Msgf("Failed to open %s session", c.String())
- return fmt.Errorf("dial %s: %w", c.String(), err)
- }
+ if err != nil {
+ log.Error().Err(err).Msgf("Failed to open %s session", c.String())
+ return fmt.Errorf("dial %s: %w", c.String(), err)
+ }
+ log.Debug().Msgf("Opened %s session", c.String())
- log.Debug().Msgf("Opened %s session", c.String())
+ c.connected = true
- return
+ return
}
func (c *Client) Close(ctx context.Context) (err error) {
- log := c.Logger(ctx)
-
- // Close SMB session
- if c.sess != nil {
- defer func() {
- if err = c.sess.Logoff(); err != nil {
- log.Debug().Err(err).Msgf("Failed to discard SMB session")
- }
- log.Debug().Msgf("Discarded SMB session")
- }()
-
- } else if c.conn != nil {
-
- defer func() {
- if err = c.conn.Close(); err != nil {
- log.Debug().Err(err).Msgf("Failed to disconnect SMB client")
- }
- log.Debug().Msgf("Disconnected SMB session")
- }()
- }
-
- // Unmount SMB share
- if c.mount != nil {
- defer func() {
- if err = c.mount.Umount(); err != nil {
- log.Debug().Err(err).Msg("Failed to unmount share")
- }
- log.Debug().Msg("Unmounted file share")
- }()
- }
- return
+ log := c.Logger(ctx)
+
+ c.connected = false
+
+ // Close SMB session
+ if c.sess != nil {
+ defer func() {
+ if err = c.sess.Logoff(); err != nil {
+ log.Debug().Err(err).Msgf("Failed to discard SMB session")
+ } else {
+ log.Debug().Msg("Discarded SMB session")
+ }
+ }()
+
+ } else if c.conn != nil {
+
+ defer func() {
+ if err = c.conn.Close(); err != nil {
+ log.Debug().Err(err).Msgf("Failed to disconnect SMB client")
+ } else {
+ log.Debug().Msg("Disconnected SMB client")
+ }
+ }()
+ }
+
+ // Unmount SMB share
+ if c.mount != nil {
+ defer func() {
+ if err = c.mount.Umount(); err != nil {
+ log.Debug().Err(err).Msg("Failed to unmount share")
+ } else {
+ log.Debug().Msg("Unmounted file share")
+ }
+ c.share = ""
+ }()
+ }
+ return
}
diff --git a/pkg/goexec/smb/input.go b/pkg/goexec/smb/input.go
new file mode 100644
index 0000000..b9cb3bc
--- /dev/null
+++ b/pkg/goexec/smb/input.go
@@ -0,0 +1,66 @@
+package smb
+
+import (
+ "context"
+ "fmt"
+ "github.com/FalconOpsLLC/goexec/pkg/goexec"
+ "io"
+ "os"
+ "path"
+ "strings"
+)
+
+type FileStager struct {
+ goexec.Cleaner
+
+ Client *Client
+
+ Share string
+ SharePath string
+ File string
+ relativePath string
+ ForceReconnect bool
+ DeleteStage bool
+}
+
+func (o *FileStager) Stage(ctx context.Context, reader io.Reader) (err error) {
+
+ o.relativePath = path.Join(
+ strings.ReplaceAll(pathPrefix.ReplaceAllString(o.SharePath, ""), `\`, "/"),
+ strings.ReplaceAll(pathPrefix.ReplaceAllString(o.File, ""), `\`, "/"),
+ )
+
+ if o.ForceReconnect || !o.Client.connected {
+ err = o.Client.Connect(ctx)
+ if err != nil {
+ return
+ }
+ defer o.AddCleaners(o.Client.Close)
+ }
+
+ if o.ForceReconnect || o.Client.share != o.Share {
+ err = o.Client.Mount(ctx, o.Share)
+ if err != nil {
+ return
+ }
+ }
+
+ writer, err := o.Client.mount.OpenFile(o.relativePath, os.O_WRONLY, 0644)
+ if err != nil {
+ return fmt.Errorf("open remote file for writing: %w", err)
+ }
+
+ if _, err = io.Copy(writer, reader); err != nil {
+ return
+ }
+
+ o.AddCleaners(func(_ context.Context) error { return writer.Close() })
+
+ if o.DeleteStage {
+ o.AddCleaners(func(_ context.Context) error {
+ return o.Client.mount.Remove(o.relativePath)
+ })
+ }
+
+ return
+}
diff --git a/pkg/goexec/smb/output.go b/pkg/goexec/smb/output.go
index f0a656d..d75cddb 100644
--- a/pkg/goexec/smb/output.go
+++ b/pkg/goexec/smb/output.go
@@ -4,9 +4,12 @@ import (
"context"
"errors"
"github.com/FalconOpsLLC/goexec/pkg/goexec"
+ "github.com/rs/zerolog"
"io"
"os"
+ "path/filepath"
"regexp"
+ "strings"
"time"
)
@@ -22,8 +25,10 @@ type OutputFileFetcher struct {
Client *Client
Share string
+ SharePath string
File string
DeleteOutputFile bool
+ ForceReconnect bool
PollInterval time.Duration
PollTimeout time.Duration
@@ -32,6 +37,8 @@ type OutputFileFetcher struct {
func (o *OutputFileFetcher) GetOutput(ctx context.Context, writer io.Writer) (err error) {
+ log := zerolog.Ctx(ctx)
+
if o.PollInterval == 0 {
o.PollInterval = DefaultOutputPollInterval
}
@@ -39,17 +46,28 @@ func (o *OutputFileFetcher) GetOutput(ctx context.Context, writer io.Writer) (er
o.PollTimeout = DefaultOutputPollTimeout
}
- o.relativePath = pathPrefix.ReplaceAllString(o.File, "")
+ shp := pathPrefix.ReplaceAllString(strings.ToLower(strings.ReplaceAll(o.SharePath, `\`, "/")), "")
+ fp := pathPrefix.ReplaceAllString(strings.ToLower(strings.ReplaceAll(o.File, `\`, "/")), "")
- err = o.Client.Connect(ctx)
- if err != nil {
+ if o.relativePath, err = filepath.Rel(shp, fp); err != nil {
return
}
- defer o.AddCleaners(o.Client.Close)
- err = o.Client.Mount(ctx, o.Share)
- if err != nil {
- return
+ log.Info().Str("path", o.relativePath).Msg("Fetching output file")
+
+ if o.ForceReconnect || !o.Client.connected {
+ err = o.Client.Connect(ctx)
+ if err != nil {
+ return
+ }
+ defer o.AddCleaners(o.Client.Close)
+ }
+
+ if o.ForceReconnect || o.Client.share != o.Share {
+ err = o.Client.Mount(ctx, o.Share)
+ if err != nil {
+ return
+ }
}
stopAt := time.Now().Add(o.PollTimeout)
@@ -57,7 +75,7 @@ func (o *OutputFileFetcher) GetOutput(ctx context.Context, writer io.Writer) (er
for {
if time.Now().After(stopAt) {
- return errors.New("output timeout")
+ return errors.New("execution output timeout")
}
if reader, err = o.Client.mount.OpenFile(o.relativePath, os.O_RDONLY, 0); err == nil {
break