aboutsummaryrefslogtreecommitdiff
path: root/pkg/goexec/smb/output.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/goexec/smb/output.go')
-rw-r--r--pkg/goexec/smb/output.go34
1 files changed, 26 insertions, 8 deletions
diff --git a/pkg/goexec/smb/output.go b/pkg/goexec/smb/output.go
index f0a656d..d75cddb 100644
--- a/pkg/goexec/smb/output.go
+++ b/pkg/goexec/smb/output.go
@@ -4,9 +4,12 @@ import (
"context"
"errors"
"github.com/FalconOpsLLC/goexec/pkg/goexec"
+ "github.com/rs/zerolog"
"io"
"os"
+ "path/filepath"
"regexp"
+ "strings"
"time"
)
@@ -22,8 +25,10 @@ type OutputFileFetcher struct {
Client *Client
Share string
+ SharePath string
File string
DeleteOutputFile bool
+ ForceReconnect bool
PollInterval time.Duration
PollTimeout time.Duration
@@ -32,6 +37,8 @@ type OutputFileFetcher struct {
func (o *OutputFileFetcher) GetOutput(ctx context.Context, writer io.Writer) (err error) {
+ log := zerolog.Ctx(ctx)
+
if o.PollInterval == 0 {
o.PollInterval = DefaultOutputPollInterval
}
@@ -39,17 +46,28 @@ func (o *OutputFileFetcher) GetOutput(ctx context.Context, writer io.Writer) (er
o.PollTimeout = DefaultOutputPollTimeout
}
- o.relativePath = pathPrefix.ReplaceAllString(o.File, "")
+ shp := pathPrefix.ReplaceAllString(strings.ToLower(strings.ReplaceAll(o.SharePath, `\`, "/")), "")
+ fp := pathPrefix.ReplaceAllString(strings.ToLower(strings.ReplaceAll(o.File, `\`, "/")), "")
- err = o.Client.Connect(ctx)
- if err != nil {
+ if o.relativePath, err = filepath.Rel(shp, fp); err != nil {
return
}
- defer o.AddCleaners(o.Client.Close)
- err = o.Client.Mount(ctx, o.Share)
- if err != nil {
- return
+ log.Info().Str("path", o.relativePath).Msg("Fetching output file")
+
+ if o.ForceReconnect || !o.Client.connected {
+ err = o.Client.Connect(ctx)
+ if err != nil {
+ return
+ }
+ defer o.AddCleaners(o.Client.Close)
+ }
+
+ if o.ForceReconnect || o.Client.share != o.Share {
+ err = o.Client.Mount(ctx, o.Share)
+ if err != nil {
+ return
+ }
}
stopAt := time.Now().Add(o.PollTimeout)
@@ -57,7 +75,7 @@ func (o *OutputFileFetcher) GetOutput(ctx context.Context, writer io.Writer) (er
for {
if time.Now().After(stopAt) {
- return errors.New("output timeout")
+ return errors.New("execution output timeout")
}
if reader, err = o.Client.mount.OpenFile(o.relativePath, os.O_RDONLY, 0); err == nil {
break