// nebula/timeout.go
//
// 194 lines
// 4.4 KiB
// Go

package nebula
import (
"time"
"github.com/slackhq/nebula/firewall"
)
// timerCacheMax caps how many spent TimeoutItem values are kept on the
// wheel's free list for reuse rather than being left to the garbage collector.
const timerCacheMax = 50000

// emptyFWPacket is the zero firewall.Packet, returned alongside false when
// there is nothing to hand back.
var emptyFWPacket firewall.Packet
// TimerWheel is a timing wheel for firewall.Packet timeouts. Each slot of the
// wheel spans one tickDuration. The wheel is ticked lazily from the elapsed
// wall time, and items whose slot has been passed are queued on the expired
// list until they are purged.
type TimerWheel struct {
	// current is the index of the slot for the tick we are currently in.
	current int
	// wheelLen caches the number of slots in wheel.
	wheelLen int
	// lastTick is when the wheel last ticked; nil until the first advance.
	lastTick *time.Time
	// tickDuration is the width of one slot; wheelDuration is the longest
	// timeout the wheel will track.
	tickDuration, wheelDuration time.Duration
	// wheel holds one singly linked list (head/tail pointers) per slot.
	wheel []*TimeoutList
	// expired is a singly linked list of items that have timed out of the wheel.
	expired *TimeoutList
	// itemCache is a free list of spent items kept to avoid garbage
	// collection; itemsCached counts how many it holds.
	itemCache   *TimeoutItem
	itemsCached int
}
// TimeoutList is a singly linked list of TimeoutItems. It is used for one
// tick (slot) of the wheel. It keeps both ends so items can be appended at
// the tail without walking the list.
type TimeoutList struct {
	Head, Tail *TimeoutItem
}
// TimeoutItem is a single tracked packet within a tick's TimeoutList.
type TimeoutItem struct {
	// Packet is the firewall packet whose timeout is being tracked.
	Packet firewall.Packet
	// Next points at the following item in the same list, or nil at the end.
	Next *TimeoutItem
}
// NewTimerWheel builds a timer wheel whose tick (resolution) is min and
// whose longest trackable timeout is max. Shorter timeouts are tracked at a
// resolution of min; longer ones are treated as max.
//
// Purge must be called once per entry to actually remove anything.
//
// min must be positive: a zero min makes the max/min division below panic,
// and a negative one yields a nonsensical (possibly negative) wheel length.
func NewTimerWheel(min, max time.Duration) *TimerWheel {
	//TODO provide an error
	//if min >= max {
	//	return nil
	//}

	// Round down and add 2 so we can have the smallest # of ticks in the wheel and still account
	// for a full max duration, even if the current tick is at the last slot and the next item is
	// added with the maximum timeout.
	//
	// An item is placed up to ceil(max/min)+1 slots ahead of the current one. When max is not a
	// multiple of min, ceil(max/min) == floor(max/min)+1. With only floor(max/min)+1 slots, such
	// an item would wrap onto a slot that expires almost immediately, or would be indexed one
	// past the end of the wheel.
	wLen := int((max / min) + 2)

	tw := TimerWheel{
		wheelLen:      wLen,
		wheel:         make([]*TimeoutList, wLen),
		tickDuration:  min,
		wheelDuration: max,
		expired:       &TimeoutList{},
	}

	for i := range tw.wheel {
		tw.wheel[i] = &TimeoutList{}
	}

	return &tw
}
// Add schedules v to time out after roughly timeout and returns the list item
// that now holds it. The wheel is first advanced to the current time. The
// item is then appended to the tail of the slot chosen for timeout. When
// possible, a previously purged item is reused instead of allocating.
func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
	// Catch up on any ticks that elapsed since the wheel was last advanced.
	tw.advance(time.Now())
	slot := tw.wheel[tw.findWheel(timeout)]

	// Pop an item off the free list, or allocate a fresh one if it is empty.
	var item *TimeoutItem
	if tw.itemCache == nil {
		item = &TimeoutItem{}
	} else {
		item = tw.itemCache
		tw.itemCache = item.Next
		tw.itemsCached--
		item.Next = nil
	}
	item.Packet = v

	// Append at the tail so items in a slot keep insertion order.
	if slot.Tail != nil {
		slot.Tail.Next = item
	} else {
		slot.Head = item
	}
	slot.Tail = item

	return item
}
// Purge removes the item at the front of the expired list. It returns that
// item's packet and true, or emptyFWPacket and false if nothing has expired.
// Each call removes at most one entry.
//
// The removed item is pushed onto the reuse cache unless the cache already
// holds timerCacheMax items.
func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
	head := tw.expired.Head
	if head == nil {
		return emptyFWPacket, false
	}

	// Unlink the head; if that empties the list, the tail must go too.
	next := head.Next
	tw.expired.Head = next
	if next == nil {
		tw.expired.Tail = nil
	}

	// Detach from the expired list, then optionally recycle onto the free list.
	head.Next = nil
	if tw.itemsCached < timerCacheMax {
		head.Next, tw.itemCache = tw.itemCache, head
		tw.itemsCached++
	}

	return head.Packet, true
}
// findWheel returns the index of the wheel slot that an item with the given
// timeout should be placed in, counted from the current tick.
//
// The timeout is clamped to [tickDuration, wheelDuration] and rounded up to
// whole ticks. One extra tick is added because the current tick may almost be
// over, so the item is not expired early.
func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
	if timeout < tw.tickDuration {
		// Can't track anything below the set resolution
		timeout = tw.tickDuration
	} else if timeout > tw.wheelDuration {
		// We aren't handling timeouts greater than the wheel's duration
		timeout = tw.wheelDuration
	}

	// Find the next highest, rounding up
	tick := int(((timeout - 1) / tw.tickDuration) + 1)

	// Add another tick since the current tick may almost be over, then map it onto the wheel
	// from our current position. Modulo (rather than a single subtraction) guarantees a valid
	// index even if the offset reaches a full wheel length. That can happen when the wheel
	// was sized too small for wheelDuration; without modulo it would index past the end of
	// the slice.
	return (tw.current + 1 + tick) % tw.wheelLen
}
// advance moves the wheel forward to now by the number of whole ticks elapsed
// since lastTick (rounded down). Each slot it steps onto has its items
// appended to the expired list. At most one full rotation is walked, since
// that visits every slot.
//
// lastTick is moved forward by exactly the elapsed whole ticks rather than set
// to now, so the fractional remainder carries into the next call.
//
// advance takes no lock itself; any synchronization is up to the caller.
func (tw *TimerWheel) advance(now time.Time) {
	// First call: start counting ticks from now.
	if tw.lastTick == nil {
		tw.lastTick = &now
	}

	elapsed := int(now.Sub(*tw.lastTick) / tw.tickDuration)

	// Walking more than one rotation would only revisit already-emptied slots.
	steps := elapsed
	if steps > tw.wheelLen {
		steps = tw.wheelLen
	}

	for ; steps > 0; steps-- {
		tw.current = (tw.current + 1) % tw.wheelLen

		slot := tw.wheel[tw.current]
		if slot.Head == nil {
			continue
		}

		// Append (not prepend) so the oldest expired items are evicted first.
		if tw.expired.Tail != nil {
			tw.expired.Tail.Next = slot.Head
		} else {
			tw.expired.Head = slot.Head
		}
		tw.expired.Tail = slot.Tail
		slot.Head, slot.Tail = nil, nil
	}

	// Step by whole ticks only, so partial-tick time is not discarded.
	next := tw.lastTick.Add(time.Duration(elapsed) * tw.tickDuration)
	tw.lastTick = &next
}