aboutsummaryrefslogtreecommitdiff
path: root/pkg/goexec/smb/output.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/goexec/smb/output.go')
-rw-r--r--pkg/goexec/smb/output.go36
1 files changed, 27 insertions, 9 deletions
diff --git a/pkg/goexec/smb/output.go b/pkg/goexec/smb/output.go
index 35c40c8..25768d1 100644
--- a/pkg/goexec/smb/output.go
+++ b/pkg/goexec/smb/output.go
@@ -2,6 +2,7 @@ package smb
import (
"context"
+ "errors"
"github.com/FalconOpsLLC/goexec/pkg/goexec"
"io"
"os"
@@ -12,22 +13,24 @@ import (
var (
DefaultOutputPollInterval = 1 * time.Second
DefaultOutputPollTimeout = 60 * time.Second
- pathPrefix = regexp.MustCompile(`^([a-zA-Z]:)?\\*`)
+ pathPrefix = regexp.MustCompile(`^([a-zA-Z]:)?[\\/]*`)
)
type OutputFileFetcher struct {
goexec.Cleaner
- Client *Client
- Share string
- File string
- PollInterval time.Duration
- PollTimeout time.Duration
+ Client *Client
+
+ Share string
+ File string
+ DeleteOutputFile bool
+ PollInterval time.Duration
+ PollTimeout time.Duration
relativePath string
}
-func (o *OutputFileFetcher) GetOutput(ctx context.Context) (reader io.ReadCloser, err error) {
+func (o *OutputFileFetcher) GetOutput(ctx context.Context, writer io.Writer) (err error) {
if o.PollInterval == 0 {
o.PollInterval = DefaultOutputPollInterval
@@ -42,6 +45,7 @@ func (o *OutputFileFetcher) GetOutput(ctx context.Context) (reader io.ReadCloser
if err != nil {
return
}
+ defer o.AddCleaner(o.Client.Close)
err = o.Client.Mount(ctx, o.Share)
if err != nil {
@@ -49,15 +53,29 @@ func (o *OutputFileFetcher) GetOutput(ctx context.Context) (reader io.ReadCloser
}
stopAt := time.Now().Add(o.PollTimeout)
+ var reader io.ReadCloser
for {
if time.Now().After(stopAt) {
- return
+ return errors.New("output timeout")
}
if reader, err = o.Client.mount.OpenFile(o.relativePath, os.O_RDONLY, 0); err == nil {
- return
+ break
}
time.Sleep(o.PollInterval)
}
+
+ if _, err = io.Copy(writer, reader); err != nil {
+ return
+ }
+
+ o.AddCleaner(func(_ context.Context) error { return reader.Close() })
+
+ if o.DeleteOutputFile {
+ o.AddCleaner(func(_ context.Context) error {
+ return o.Client.mount.Remove(o.relativePath)
+ })
+ }
+
return
}