Import first version

master
Robert Jacob 2022-03-16 00:36:43 +01:00
parent 553db94cd4
commit 252888f650
10 changed files with 719 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
leakybot
config.yaml
state.yaml

12
go.mod Normal file
View File

@ -0,0 +1,12 @@
module git.hacknology.de/projekte/leakybot
go 1.17
require (
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/sirupsen/logrus v1.8.1
github.com/spf13/pflag v1.0.5
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
)
require golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 // indirect

18
go.sum Normal file
View File

@ -0,0 +1,18 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

148
internal/config/config.go Normal file
View File

@ -0,0 +1,148 @@
package config
import (
"errors"
"fmt"
"os"
"time"
"github.com/sirupsen/logrus"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3"
)
const (
envLogLevel = "LOG_LEVEL"
)
type Config struct {
ConfigFilePath string `yaml:"-"`
LogLevel logrus.Level `yaml:"-"`
Account *AccountConfig `yaml:"account"`
Rules *RulesConfig `yaml:"rules"`
}
func (c *Config) validate() error {
if c.Account == nil {
return errors.New("no account configuration")
}
if err := c.Account.validate(); err != nil {
return err
}
if c.Rules == nil {
return errors.New("no rules configuration")
}
return c.Rules.validate()
}
type AccountConfig struct {
HomeServer string `yaml:"home-server"`
Username string `yaml:"username"`
AccessToken string `yaml:"token"`
}
func (c *AccountConfig) validate() error {
if c.HomeServer == "" {
return errors.New("account.home-server can not be empty")
}
if c.Username == "" {
return errors.New("account.username can not be empty")
}
if c.AccessToken == "" {
return errors.New("account.token can not be empty")
}
return nil
}
type RulesConfig struct {
StateFilePath string `yaml:"state-file"`
WarnBucket BucketConfig `yaml:"warn"`
BanBucket BucketConfig `yaml:"ban"`
}
func (c *RulesConfig) validate() error {
if c.StateFilePath == "" {
return errors.New("rules.state-file can not be empty")
}
if err := c.WarnBucket.validate(); err != nil {
return fmt.Errorf("error in warn: %w", err)
}
if err := c.BanBucket.validate(); err != nil {
return fmt.Errorf("error in ban: %w", err)
}
return nil
}
type BucketConfig struct {
Duration time.Duration `yaml:"duration"`
Threshold int `yaml:"threshold"`
EffectDuration time.Duration `yaml:"effect"`
}
func (c BucketConfig) validate() error {
if c.Duration == 0 {
return errors.New("duration can not be zero")
}
if c.Threshold == 0 {
return errors.New("threshold can not be zero")
}
if c.EffectDuration < c.Duration {
return errors.New("effect-duration can not be smaller than duration")
}
return nil
}
func Get(cmd string, args []string, envFunc func(string) string) (*Config, error) {
cfg := &Config{
ConfigFilePath: "config.yaml",
LogLevel: logrus.InfoLevel,
}
flagSet := pflag.NewFlagSet(cmd, pflag.ContinueOnError)
flagSet.StringVarP(&cfg.ConfigFilePath, "config-file", "c", cfg.ConfigFilePath, "Path of configuration file.")
if err := flagSet.Parse(args); err != nil {
return nil, err
}
if cfg.ConfigFilePath == "" {
return nil, errors.New("--config-file can not be empty")
}
configFile, err := os.Open(cfg.ConfigFilePath)
if err != nil {
return nil, fmt.Errorf("can not open configuration file %q: %w", cfg.ConfigFilePath, err)
}
defer configFile.Close()
if err := yaml.NewDecoder(configFile).Decode(cfg); err != nil {
return nil, fmt.Errorf("error parsing configuration file %q: %w", cfg.ConfigFilePath, err)
}
if err := cfg.validate(); err != nil {
return nil, err
}
if raw := envFunc(envLogLevel); raw != "" {
lvl, err := logrus.ParseLevel(raw)
if err != nil {
return nil, fmt.Errorf("can not parse %s %q: %w", envLogLevel, raw, err)
}
cfg.LogLevel = lvl
}
return cfg, nil
}

View File

@ -0,0 +1,22 @@
package httputil
import (
"context"
"net/http"
)
type contextTransport struct {
ctx context.Context
inner http.RoundTripper
}
func ContextTransport(ctx context.Context, inner http.RoundTripper) http.RoundTripper {
return &contextTransport{
ctx: ctx,
inner: inner,
}
}
func (ct *contextTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return ct.inner.RoundTrip(req.WithContext(ct.ctx))
}

50
internal/leaky/bucket.go Normal file
View File

@ -0,0 +1,50 @@
package leaky
import (
"sync"
"time"
)
type Bucket struct {
sync.RWMutex
Clock func() time.Time
Duration time.Duration
Threshold int
Entries map[string][]time.Time
}
func NewBucket(clock func() time.Time, duration time.Duration, threshold int) *Bucket {
return &Bucket{
Clock: clock,
Duration: duration,
Threshold: threshold,
Entries: map[string][]time.Time{},
}
}
func (b *Bucket) Push(timeStamp time.Time, id string) (level int, ok bool) {
b.Lock()
defer b.Unlock()
now := b.Clock()
items, found := b.Entries[id]
if !found {
b.Entries[id] = []time.Time{
timeStamp,
}
return 1, true
}
newItems := []time.Time{timeStamp}
for _, t := range items {
if now.Sub(t) <= b.Duration {
newItems = append(newItems, t)
}
}
b.Entries[id] = newItems
level = len(newItems)
return level, level < b.Threshold
}

61
internal/power/power.go Normal file
View File

@ -0,0 +1,61 @@
package power
import (
"errors"
"net/http"
"github.com/matrix-org/gomatrix"
"github.com/sirupsen/logrus"
)
func ModifyLevel(log logrus.FieldLogger, client *gomatrix.Client, room, user string, level int) error {
u := client.BuildURL("rooms", room, "state", "m.room.power_levels")
levels := map[string]interface{}{}
err := client.MakeRequest(http.MethodGet, u, nil, &levels)
if err != nil {
return err
}
usersRaw, ok := levels["users"]
if !ok {
return errors.New("users field not found")
}
users, ok := usersRaw.(map[string]interface{})
if !ok {
return errors.New("users not a map")
}
users[user] = level
return sendLevel(log, client, room, levels)
}
func RemoveLevel(log logrus.FieldLogger, client *gomatrix.Client, room, user string) error {
u := client.BuildURL("rooms", room, "state", "m.room.power_levels")
levels := map[string]interface{}{}
err := client.MakeRequest(http.MethodGet, u, nil, &levels)
if err != nil {
return err
}
usersRaw, ok := levels["users"]
if !ok {
return errors.New("users field not found")
}
users, ok := usersRaw.(map[string]interface{})
if !ok {
return errors.New("users not a map")
}
delete(users, user)
return sendLevel(log, client, room, levels)
}
func sendLevel(log logrus.FieldLogger, client *gomatrix.Client, room string, levels map[string]interface{}) error {
u := client.BuildURL("rooms", room, "state", "m.room.power_levels")
resp := map[string]interface{}{}
return client.MakeRequest(http.MethodPut, u, levels, &resp)
}

125
internal/ruler/ruler.go Normal file
View File

@ -0,0 +1,125 @@
package ruler
import (
"context"
"fmt"
"sync"
"time"
"github.com/sirupsen/logrus"
"git.hacknology.de/projekte/leakybot/internal/config"
"git.hacknology.de/projekte/leakybot/internal/leaky"
)
type Funcs struct {
Warn func(id string) error
Unwarn func(id string) error
Ban func(id string) error
Unban func(id string) error
}
type Ruler struct {
log logrus.FieldLogger
cfg *config.RulesConfig
clock func() time.Time
funcs Funcs
warnBucket *leaky.Bucket
banBucket *leaky.Bucket
state *rulerState
}
func New(log logrus.FieldLogger, cfg *config.RulesConfig, clock func() time.Time, funcs Funcs) (*Ruler, error) {
state, err := restoreState(log, cfg.StateFilePath)
if err != nil {
return nil, fmt.Errorf("can not restore state: %w", err)
}
return &Ruler{
log: log,
cfg: cfg,
clock: clock,
funcs: funcs,
warnBucket: leaky.NewBucket(clock, cfg.WarnBucket.Duration, cfg.WarnBucket.Threshold),
banBucket: leaky.NewBucket(clock, cfg.BanBucket.Duration, cfg.BanBucket.Threshold),
state: state,
}, nil
}
func (r *Ruler) Start(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
go func() {
defer wg.Done()
t := time.NewTicker(time.Second)
for {
if ctx.Err() != nil {
r.log.Debugf("Shutting down ruler...")
if err := saveState(r.state, r.cfg.StateFilePath); err != nil {
r.log.Errorf("Error saving ruler state: %s", err)
}
return
}
select {
case <-ctx.Done():
case ts := <-t.C:
unwarn, unban := r.state.CheckExpired(ts)
for _, id := range unwarn {
r.log.Infof("Reset warning for %s", id)
if err := r.funcs.Unwarn(id); err != nil {
r.log.Errorf("Error unwarning %q: %s", id, err)
}
r.state.Unwarn(id)
}
for _, id := range unban {
r.log.Infof("Reset ban for %s", id)
if err := r.funcs.Unban(id); err != nil {
r.log.Errorf("Error unbanning %q: %s", id, err)
}
r.state.Unban(id)
}
}
}
}()
}
func (r *Ruler) PushEvent(ts time.Time, id string) error {
if r.state.IsBanned(id) || r.state.IsWarned(id) {
return nil
}
warnLevel, ok := r.warnBucket.Push(ts, id)
r.log.Debugf("Warning level for %s = %d", id, warnLevel)
if !ok {
now := r.clock()
banLevel, banOk := r.banBucket.Push(now, id)
r.log.Debugf("Ban level for %s = %d", id, banLevel)
if banOk {
until := now.Add(r.cfg.WarnBucket.EffectDuration)
r.log.Infof("Warning user %s until %s", id, until)
if err := r.funcs.Warn(id); err == nil {
r.state.SetWarned(id, until)
} else {
r.log.Errorf("Error warning %q: %s", id, err)
}
} else {
banUntil := now.Add(r.cfg.BanBucket.EffectDuration)
r.log.Infof("Banning user %s until %s", id, banUntil)
if err := r.funcs.Ban(id); err == nil {
r.state.SetBanned(id, banUntil)
} else {
r.log.Errorf("Error banning %q: %s", id, err)
}
}
}
return nil
}

137
internal/ruler/state.go Normal file
View File

@ -0,0 +1,137 @@
package ruler
import (
"fmt"
"os"
"sync"
"time"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
)
type rulerState struct {
mutex *sync.RWMutex
warned map[string]time.Time
banned map[string]time.Time
}
type savedState struct {
Warned map[string]time.Time `yaml:"warned"`
Banned map[string]time.Time `yaml:"banned"`
}
func restoreState(log logrus.FieldLogger, fileName string) (*rulerState, error) {
state := rulerState{
mutex: &sync.RWMutex{},
warned: map[string]time.Time{},
banned: map[string]time.Time{},
}
file, err := os.Open(fileName)
if os.IsNotExist(err) {
log.Debugf("No state restored from %s", fileName)
return &state, nil
}
if err != nil {
return nil, fmt.Errorf("error opening file: %w", err)
}
defer file.Close()
saved := savedState{}
if err := yaml.NewDecoder(file).Decode(&saved); err != nil {
return nil, fmt.Errorf("error decoding state: %w", err)
}
for k, v := range saved.Warned {
state.warned[k] = v
}
for k, v := range saved.Banned {
state.banned[k] = v
}
return &state, nil
}
func saveState(state *rulerState, fileName string) error {
file, err := os.Create(fileName)
if err != nil {
return fmt.Errorf("error creating file: %w", err)
}
defer file.Close()
saved := savedState{
Warned: state.warned,
Banned: state.banned,
}
if err := yaml.NewEncoder(file).Encode(saved); err != nil {
return fmt.Errorf("error writing state: %w", err)
}
return nil
}
func (s *rulerState) IsBanned(id string) bool {
s.mutex.RLock()
defer s.mutex.RUnlock()
_, ok := s.banned[id]
return ok
}
func (s *rulerState) IsWarned(id string) bool {
s.mutex.RLock()
defer s.mutex.RUnlock()
_, ok := s.warned[id]
return ok
}
func (s *rulerState) SetWarned(id string, until time.Time) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.warned[id] = until
}
func (s *rulerState) Unwarn(id string) {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.warned, id)
}
func (s *rulerState) SetBanned(id string, until time.Time) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.banned[id] = until
}
func (s *rulerState) Unban(id string) {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.banned, id)
}
func (s *rulerState) CheckExpired(ts time.Time) (unwarn, unban []string) {
s.mutex.RLock()
defer s.mutex.RUnlock()
for id, until := range s.warned {
if until.Before(ts) {
unwarn = append(unwarn, id)
}
}
for id, until := range s.banned {
if until.Before(ts) {
unban = append(unban, id)
}
}
return unwarn, unban
}

143
main.go Normal file
View File

@ -0,0 +1,143 @@
package main
import (
"context"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/matrix-org/gomatrix"
"github.com/sirupsen/logrus"
"git.hacknology.de/projekte/leakybot/internal/config"
"git.hacknology.de/projekte/leakybot/internal/httputil"
"git.hacknology.de/projekte/leakybot/internal/power"
"git.hacknology.de/projekte/leakybot/internal/ruler"
)
const (
idSeparator = "--"
)
var (
signals = []os.Signal{syscall.SIGINT, syscall.SIGTERM}
log = &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{
DisableTimestamp: true,
},
Hooks: make(logrus.LevelHooks),
Level: logrus.InfoLevel,
ExitFunc: os.Exit,
ReportCaller: false,
}
)
func formatID(user, room string) string {
return strings.Join([]string{user, room}, idSeparator)
}
func splitID(id string) (user, room string) {
tokens := strings.SplitN(id, idSeparator, 2)
return tokens[0], tokens[1]
}
func main() {
cfg, err := config.Get(os.Args[0], os.Args[1:], os.Getenv)
if err != nil {
log.Fatalf("Error getting config: %s", err)
}
log.SetLevel(cfg.LogLevel)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wg := &sync.WaitGroup{}
log.Infof("Homeserver: %s User: %s", cfg.Account.HomeServer, cfg.Account.Username)
client, err := gomatrix.NewClient(cfg.Account.HomeServer, cfg.Account.Username, cfg.Account.AccessToken)
if err != nil {
log.Fatalf("Error creating client: %s", err)
}
client.Client = &http.Client{
Timeout: 60 * time.Second,
Transport: httputil.ContextTransport(ctx, http.DefaultTransport),
}
rooms, err := client.JoinedRooms()
if err != nil {
log.Fatalf("Error retrieving rooms: %s", err)
}
for _, r := range rooms.JoinedRooms {
log.Debugf("Found room: %s", r)
}
rulerFuncs := ruler.Funcs{
Warn: func(id string) error {
user, room := splitID(id)
return power.ModifyLevel(log, client, room, user, -1)
},
Unwarn: func(id string) error {
user, room := splitID(id)
return power.RemoveLevel(log, client, room, user)
},
Ban: func(id string) error {
user, room := splitID(id)
_, err := client.BanUser(room, &gomatrix.ReqBanUser{
Reason: "spam",
UserID: user,
})
return err
},
Unban: func(string) error {
return nil
},
}
r, err := ruler.New(log, cfg.Rules, time.Now, rulerFuncs)
if err != nil {
log.Fatalf("Error creating ruler: %s", err)
}
syncer := client.Syncer.(*gomatrix.DefaultSyncer)
syncer.OnEventType("m.room.message", func(evt *gomatrix.Event) {
id := formatID(evt.Sender, evt.RoomID)
timeStamp := time.Unix(evt.Timestamp/1000, 0)
log.Debugf("Message from %s @ %s", id, timeStamp)
if err := r.PushEvent(timeStamp, id); err != nil {
log.Errorf("Error pushing event to ruler: %s", err)
}
})
sigCh := make(chan os.Signal)
signal.Notify(sigCh, signals...)
go func() {
sig := <-sigCh
signal.Reset(signals...)
log.Debugf("Got signal %q. Terminating...", sig)
cancel()
}()
r.Start(ctx, wg)
log.Debug("Starting synchronization loop...")
for {
if ctx.Err() != nil {
break
}
if err := client.Sync(); err != nil {
log.Errorf("error during sync: %s", err)
}
}
}