200 lines
3.0 KiB
Go

package storage
import (
"bufio"
"encoding/json"
"io"
"log"
"os"
"path/filepath"
"sort"
"sync/atomic"
"time"
"git.akyoto.dev/go/ocean"
)
const (
// diskWriteInterval is how long the flush worker sleeps between two
// checks of the dirty flag.
diskWriteInterval = 100 * time.Millisecond
// fileExtension is appended to the collection name to build the name of
// the data file inside the collection's root directory.
fileExtension = ".dat"
)
// File is a storage that persists a collection to a single file on disk.
// T is the element type that stored values are decoded into when the file
// is loaded.
type File[T any] struct {
// collection supplies the root directory, the collection name and the
// in-memory data that is read from and written to disk.
collection ocean.StorageData
// dirty is set to 1 by Set and Delete and swapped back to 0 by the flush
// worker when it picks up the pending write.
dirty atomic.Uint32
// sync is the channel the flush worker signals on when it finds nothing
// to write; Sync receives from it.
sync chan struct{}
}
// Init binds the storage to the collection c, starts the background flush
// worker and loads previously persisted records from
// "<root>/<name>.dat" into the collection's in-memory data.
//
// A missing data file is not an error: the collection simply starts empty.
// Any other failure to open or parse the file is returned to the caller.
func (fs *File[T]) Init(c ocean.StorageData) error {
	fs.collection = c
	fs.sync = make(chan struct{})
	go fs.flushWorker()

	file, err := os.Open(filepath.Join(c.Root(), c.Name()+fileExtension))

	switch {
	case os.IsNotExist(err):
		// Nothing has been persisted yet.
		return nil
	case err != nil:
		return err
	}

	defer file.Close()
	return fs.readFrom(file)
}
// Delete marks the storage as dirty so that the flush worker rewrites the
// data file on its next tick. The key is not used here: a flush always
// writes the complete in-memory collection.
// NOTE(review): presumably the caller has already removed the key from the
// in-memory data before calling this — confirm against the caller.
// It always returns nil.
func (fs *File[T]) Delete(key string) error {
fs.dirty.Store(1)
return nil
}
// Set marks the storage as dirty so that the flush worker rewrites the data
// file on its next tick. Neither key nor value is used here: a flush always
// writes the complete in-memory collection.
// NOTE(review): presumably the caller has already stored the value in the
// in-memory data before calling this — confirm against the caller.
// It always returns nil.
func (fs *File[T]) Set(key string, value *T) error {
fs.dirty.Store(1)
return nil
}
// Sync blocks until the flush worker signals on the sync channel, which it
// does on a tick where the dirty flag was already clear (no write pending).
// The channel is created in Init, so Sync must not be called before Init:
// receiving from the nil channel would block forever.
func (fs *File[T]) Sync() {
<-fs.sync
}
// flushWorker runs forever in its own goroutine. Once per diskWriteInterval
// it atomically clears the dirty flag; if the flag was set, the collection
// is flushed to disk and any error is logged. If the flag was already clear,
// a goroutine waiting in Sync (if any) is released instead.
func (fs *File[T]) flushWorker() {
	for {
		time.Sleep(diskWriteInterval)

		if fs.dirty.Swap(0) != 0 {
			if err := fs.flush(); err != nil {
				log.Println(err)
			}

			continue
		}

		// Nothing to write: notify a waiting Sync call without blocking
		// when nobody is listening.
		select {
		case fs.sync <- struct{}{}:
		default:
		}
	}
}
// flush writes the whole collection to "<data file>.tmp", syncs and closes
// that file and then renames it over the data file, so the data file is
// only ever replaced by a completely written copy.
//
// The first error encountered is returned; the temporary file is closed
// before returning in every case.
func (fs *File[T]) flush() error {
	target := filepath.Join(fs.collection.Root(), fs.collection.Name()+fileExtension)
	temporary := target + ".tmp"

	file, err := os.OpenFile(temporary, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)

	if err != nil {
		return err
	}

	// abort closes the temporary file and hands back the original failure.
	abort := func(cause error) error {
		file.Close()
		return cause
	}

	buffer := bufio.NewWriter(file)

	if err := fs.writeTo(buffer); err != nil {
		return abort(err)
	}

	if err := buffer.Flush(); err != nil {
		return abort(err)
	}

	// Force the contents to disk before the rename makes them visible.
	if err := file.Sync(); err != nil {
		return abort(err)
	}

	if err := file.Close(); err != nil {
		return err
	}

	return os.Rename(temporary, target)
}
// readFrom loads the entire collection from stream into the collection's
// in-memory data.
//
// The stream is read line by line as pairs: a key line followed by a line
// holding the JSON encoding of a value of type T. Empty lines in key
// position are skipped, and a trailing key without a value line is ignored.
//
// It returns the first JSON decoding error, or the scanner's error if
// reading stopped for any reason other than end of input (for example an
// I/O failure or a line exceeding the scanner's maximum token size).
func (fs *File[T]) readFrom(stream io.Reader) error {
	var key string

	scanner := bufio.NewScanner(stream)

	for scanner.Scan() {
		// An empty key means this line is the key half of the next pair.
		if key == "" {
			key = scanner.Text()
			continue
		}

		object := new(T)

		// The slice returned by Bytes is only valid until the next Scan,
		// so it is decoded immediately.
		if err := json.Unmarshal(scanner.Bytes(), object); err != nil {
			return err
		}

		fs.collection.Data().Store(key, object)
		key = ""
	}

	// Scan returns false on failure as well as on end of input; without
	// this check a truncated load would be reported as success.
	return scanner.Err()
}
// writeTo writes the entire collection to writer.
//
// Records are emitted in ascending key order, which keeps the output
// independent of the iteration order of the in-memory data. Each record is
// written as the key followed by a newline and then the value as produced
// by the encoder returned from NewEncoder.
//
// Any io.Writer is accepted: keys are written with io.WriteString, which
// uses the writer's WriteString method when it has one and falls back to
// Write otherwise. The first write or encoding error is returned.
func (fs *File[T]) writeTo(writer io.Writer) error {
	var records []keyValue

	fs.collection.Data().Range(func(key, value any) bool {
		records = append(records, keyValue{
			key:   key.(string),
			value: value,
		})

		return true
	})

	sort.Slice(records, func(i, j int) bool {
		return records[i].key < records[j].key
	})

	encoder := NewEncoder(writer)

	for _, record := range records {
		if _, err := io.WriteString(writer, record.key); err != nil {
			return err
		}

		if _, err := io.WriteString(writer, "\n"); err != nil {
			return err
		}

		if err := encoder.Encode(record.value); err != nil {
			return err
		}
	}

	return nil
}