package asm

import (
	"encoding/binary"
	"fmt"
	"math"
	"slices"
	"strings"

	"git.akyoto.dev/cli/q/src/config"
	"git.akyoto.dev/cli/q/src/dll"
	"git.akyoto.dev/cli/q/src/elf"
	"git.akyoto.dev/cli/q/src/fs"
	"git.akyoto.dev/cli/q/src/macho"
	"git.akyoto.dev/cli/q/src/pe"
	"git.akyoto.dev/cli/q/src/sizeof"
	"git.akyoto.dev/cli/q/src/x64"
)

// Finalize generates the final machine code.
//
// It works in three phases:
//
//  1. Every instruction is encoded via the x64 package. Operands that refer
//     to labels are emitted as zero placeholders and recorded as pointers.
//  2. Relative code pointers (calls and jumps) are resolved. Jumps are first
//     emitted with a 1-byte displacement; whenever the distance does not fit,
//     the instruction is replaced by its 4-byte displacement form, all
//     positions behind it are shifted and resolution starts over.
//  3. Absolute pointers to functions and data, and on Windows the pointers
//     into the DLL import table, are patched using the section layout
//     derived from the target's header size and config.Align.
//
// The first return value is the code section, the second the data section
// produced by a.Data.Finalize. dlls is used to look up the import index of
// every DLLCALL target named "library.function".
//
// Finalize panics on an unknown mnemonic, an unknown label, an unknown DLL
// function or a jump opcode that cannot be widened.
func (a Assembler) Finalize(dlls dll.List) ([]byte, []byte) {
	var (
		code         = make([]byte, 0, len(a.Instructions)*8)
		data         []byte
		codeLabels   = map[string]Address{}
		dataLabels   map[string]Address
		codePointers []*Pointer
		dataPointers []*Pointer
		funcPointers []*Pointer
		dllPointers  []*Pointer
	)

	for _, x := range a.Instructions {
		switch x.Mnemonic {
		case ADD:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.AddRegisterNumber(code, operands.Register, operands.Number)
			case *RegisterRegister:
				code = x64.AddRegisterRegister(code, operands.Destination, operands.Source)
			}

		case AND:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.AndRegisterNumber(code, operands.Register, operands.Number)
			case *RegisterRegister:
				code = x64.AndRegisterRegister(code, operands.Destination, operands.Source)
			}

		case SUB:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.SubRegisterNumber(code, operands.Register, operands.Number)
			case *RegisterRegister:
				code = x64.SubRegisterRegister(code, operands.Destination, operands.Source)
			}

		case MUL:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.MulRegisterNumber(code, operands.Register, operands.Number)
			case *RegisterRegister:
				code = x64.MulRegisterRegister(code, operands.Destination, operands.Source)
			}

		case DIV:
			switch operands := x.Data.(type) {
			case *RegisterRegister:
				// The dividend is moved into RAX first and the quotient is
				// copied back from RAX afterwards.
				if operands.Destination != x64.RAX {
					code = x64.MoveRegisterRegister(code, x64.RAX, operands.Destination)
				}

				code = x64.ExtendRAXToRDX(code)
				code = x64.DivRegister(code, operands.Source)

				if operands.Destination != x64.RAX {
					code = x64.MoveRegisterRegister(code, operands.Destination, x64.RAX)
				}
			}

		case MODULO:
			switch operands := x.Data.(type) {
			case *RegisterRegister:
				// Same sequence as DIV, but the result is taken from RDX.
				if operands.Destination != x64.RAX {
					code = x64.MoveRegisterRegister(code, x64.RAX, operands.Destination)
				}

				code = x64.ExtendRAXToRDX(code)
				code = x64.DivRegister(code, operands.Source)

				if operands.Destination != x64.RDX {
					code = x64.MoveRegisterRegister(code, operands.Destination, x64.RDX)
				}
			}

		case CALL:
			code = x64.Call(code, 0x00_00_00_00)
			size := 4
			label := x.Data.(*Label)

			pointer := &Pointer{
				Position: Address(len(code) - size),
				OpSize:   1,
				Size:     uint8(size),
			}

			// The distance is measured from the end of the placeholder and
			// reads the pointer's fields at resolve time because they can
			// still change while jumps are being widened.
			pointer.Resolve = func() Address {
				destination, exists := codeLabels[label.Name]

				if !exists {
					panic("unknown jump label")
				}

				distance := destination - (pointer.Position + Address(pointer.Size))
				return Address(distance)
			}

			codePointers = append(codePointers, pointer)

		case COMMENT:
			continue

		case COMPARE:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.CompareRegisterNumber(code, operands.Register, operands.Number)
			case *RegisterRegister:
				code = x64.CompareRegisterRegister(code, operands.Destination, operands.Source)
			}

		case DLLCALL:
			size := 4

			// RSP is saved in R15, the stack is aligned and 32 bytes are
			// reserved before the call (presumably the Windows x64 shadow
			// space - confirm); RSP is restored from R15 afterwards.
			// TODO: R15 could be in use.
			code = x64.MoveRegisterRegister(code, x64.R15, x64.RSP)
			code = x64.AlignStack(code)
			code = x64.SubRegisterNumber(code, x64.RSP, 32)
			code = x64.CallAtAddress(code, 0x00_00_00_00)
			position := len(code) - size
			code = x64.MoveRegisterRegister(code, x64.RSP, x64.R15)

			label := x.Data.(*Label)

			pointer := &Pointer{
				Position: Address(position),
				OpSize:   2,
				Size:     uint8(size),
			}

			pointer.Resolve = func() Address {
				// The label has the form "library.function". A name without
				// a dot can never match an import, so it is reported the
				// same way as an unknown function instead of failing with a
				// slice bounds error.
				library, funcName, found := strings.Cut(label.Name, ".")

				if !found {
					panic("unknown DLL function " + label.Name)
				}

				index := dlls.Index(library, funcName)

				if index == -1 {
					panic("unknown DLL function " + label.Name)
				}

				// Every import occupies 8 bytes.
				return Address(index * 8)
			}

			dllPointers = append(dllPointers, pointer)

		case JE, JNE, JG, JGE, JL, JLE, JUMP:
			// Jumps start out with a 1-byte displacement and are widened
			// during pointer resolution if that turns out to be too small.
			switch x.Mnemonic {
			case JE:
				code = x64.Jump8IfEqual(code, 0x00)
			case JNE:
				code = x64.Jump8IfNotEqual(code, 0x00)
			case JG:
				code = x64.Jump8IfGreater(code, 0x00)
			case JGE:
				code = x64.Jump8IfGreaterOrEqual(code, 0x00)
			case JL:
				code = x64.Jump8IfLess(code, 0x00)
			case JLE:
				code = x64.Jump8IfLessOrEqual(code, 0x00)
			case JUMP:
				code = x64.Jump8(code, 0x00)
			}

			size := 1
			label := x.Data.(*Label)

			pointer := &Pointer{
				Position: Address(len(code) - size),
				OpSize:   1,
				Size:     uint8(size),
			}

			pointer.Resolve = func() Address {
				destination, exists := codeLabels[label.Name]

				if !exists {
					panic("unknown jump label")
				}

				distance := destination - (pointer.Position + Address(pointer.Size))
				return Address(distance)
			}

			codePointers = append(codePointers, pointer)

		case LABEL:
			codeLabels[x.Data.(*Label).Name] = Address(len(code))

		case LOAD:
			switch operands := x.Data.(type) {
			case *MemoryRegister:
				code = x64.LoadRegister(code, operands.Register, operands.Address.Offset, operands.Address.Length, operands.Address.Base)
			}

		case MOVE:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.MoveRegisterNumber(code, operands.Register, operands.Number)

			case *RegisterRegister:
				code = x64.MoveRegisterRegister(code, operands.Destination, operands.Source)

			case *RegisterLabel:
				start := len(code)
				code = x64.MoveRegisterNumber(code, operands.Register, 0x00_00_00_00)
				size := 4
				opSize := len(code) - size - start
				regLabel := x.Data.(*RegisterLabel)

				// Labels prefixed with "data_" live in the data section,
				// everything else is treated as a code label.
				if strings.HasPrefix(regLabel.Label, "data_") {
					dataPointers = append(dataPointers, &Pointer{
						Position: Address(len(code) - size),
						OpSize:   uint8(opSize),
						Size:     uint8(size),
						Resolve: func() Address {
							destination, exists := dataLabels[regLabel.Label]

							if !exists {
								panic("unknown label")
							}

							return Address(destination)
						},
					})
				} else {
					funcPointers = append(funcPointers, &Pointer{
						Position: Address(len(code) - size),
						OpSize:   uint8(opSize),
						Size:     uint8(size),
						Resolve: func() Address {
							destination, exists := codeLabels[regLabel.Label]

							if !exists {
								panic("unknown label")
							}

							return Address(destination)
						},
					})
				}
			}

		case NEGATE:
			switch operands := x.Data.(type) {
			case *Register:
				code = x64.NegateRegister(code, operands.Register)
			}

		case OR:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.OrRegisterNumber(code, operands.Register, operands.Number)
			case *RegisterRegister:
				code = x64.OrRegisterRegister(code, operands.Destination, operands.Source)
			}

		case POP:
			switch operands := x.Data.(type) {
			case *Register:
				code = x64.PopRegister(code, operands.Register)
			}

		case PUSH:
			switch operands := x.Data.(type) {
			case *Register:
				code = x64.PushRegister(code, operands.Register)
			}

		case RETURN:
			code = x64.Return(code)

		case SHIFTL:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				// Only the lowest 6 bits of the shift count are used.
				code = x64.ShiftLeftNumber(code, operands.Register, byte(operands.Number)&0b111111)
			}

		case SHIFTRS:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				// Only the lowest 6 bits of the shift count are used.
				code = x64.ShiftRightSignedNumber(code, operands.Register, byte(operands.Number)&0b111111)
			}

		case STORE:
			// An OffsetRegister of math.MaxUint8 means "no offset register":
			// the constant offset is used instead.
			switch operands := x.Data.(type) {
			case *MemoryNumber:
				if operands.Address.OffsetRegister == math.MaxUint8 {
					code = x64.StoreNumber(code, operands.Address.Base, operands.Address.Offset, operands.Address.Length, operands.Number)
				} else {
					code = x64.StoreDynamicNumber(code, operands.Address.Base, operands.Address.OffsetRegister, operands.Address.Length, operands.Number)
				}

			case *MemoryLabel:
				start := len(code)

				if operands.Address.OffsetRegister == math.MaxUint8 {
					code = x64.StoreNumber(code, operands.Address.Base, operands.Address.Offset, operands.Address.Length, 0x00_00_00_00)
				} else {
					code = x64.StoreDynamicNumber(code, operands.Address.Base, operands.Address.OffsetRegister, operands.Address.Length, 0x00_00_00_00)
				}

				// NOTE(review): this assumes the placeholder was encoded as
				// the trailing 4 bytes of the instruction - confirm for
				// address lengths other than 4.
				size := 4
				opSize := len(code) - size - start
				memLabel := x.Data.(*MemoryLabel)

				funcPointers = append(funcPointers, &Pointer{
					Position: Address(len(code) - size),
					OpSize:   uint8(opSize),
					Size:     uint8(size),
					Resolve: func() Address {
						destination, exists := codeLabels[memLabel.Label]

						if !exists {
							panic("unknown label")
						}

						return Address(destination)
					},
				})

			case *MemoryRegister:
				if operands.Address.OffsetRegister == math.MaxUint8 {
					code = x64.StoreRegister(code, operands.Address.Base, operands.Address.Offset, operands.Address.Length, operands.Register)
				} else {
					code = x64.StoreDynamicRegister(code, operands.Address.Base, operands.Address.OffsetRegister, operands.Address.Length, operands.Register)
				}
			}

		case SYSCALL:
			code = x64.Syscall(code)

		case XOR:
			switch operands := x.Data.(type) {
			case *RegisterNumber:
				code = x64.XorRegisterNumber(code, operands.Register, operands.Number)
			case *RegisterRegister:
				code = x64.XorRegisterRegister(code, operands.Destination, operands.Source)
			}

		default:
			panic("unknown mnemonic: " + x.Mnemonic.String())
		}
	}

restart:
	for i, pointer := range codePointers {
		address := pointer.Resolve()

		if sizeof.Signed(int64(address)) > int(pointer.Size) {
			// Bounds of the short instruction. They are captured before the
			// pointer is modified because everything located at or behind
			// `end` has to move by the number of bytes the instruction grows.
			start := pointer.Position - Address(pointer.OpSize)
			end := pointer.Position + Address(pointer.Size)
			left := code[:start]
			right := code[end:]
			opCode := code[start]

			var jump []byte

			switch opCode {
			case 0x74: // JE
				jump = []byte{0x0F, 0x84}
			case 0x75: // JNE
				jump = []byte{0x0F, 0x85}
			case 0x7C: // JL
				jump = []byte{0x0F, 0x8C}
			case 0x7D: // JGE
				jump = []byte{0x0F, 0x8D}
			case 0x7E: // JLE
				jump = []byte{0x0F, 0x8E}
			case 0x7F: // JG
				jump = []byte{0x0F, 0x8F}
			case 0xEB: // JMP
				jump = []byte{0xE9}
			default:
				panic(fmt.Errorf("failed to increase pointer size for instruction 0x%x", opCode))
			}

			pointer.Position += Address(len(jump) - int(pointer.OpSize))
			pointer.OpSize = uint8(len(jump))
			pointer.Size = 4

			// The displacement written here is only a placeholder: after the
			// restart every pointer is resolved again with the new layout.
			jump = binary.LittleEndian.AppendUint32(jump, uint32(address))
			offset := Address(len(jump)) - (end - start)

			// Code pointers are ordered by position, so all later entries
			// are located behind the widened instruction.
			for _, following := range codePointers[i+1:] {
				following.Position += offset
			}

			// The remaining pointer lists are patched after this loop using
			// their recorded positions, so they must follow the code too.
			for _, pointers := range [][]*Pointer{funcPointers, dataPointers, dllPointers} {
				for _, other := range pointers {
					if other.Position >= end {
						other.Position += offset
					}
				}
			}

			// Compare against the old end of the instruction: a label placed
			// directly behind the jump must be shifted as well, which a
			// comparison with the already updated position would miss for
			// the two-byte conditional opcodes.
			for key, labelAddress := range codeLabels {
				if labelAddress >= end {
					codeLabels[key] = labelAddress + offset
				}
			}

			code = slices.Concat(left, jump, right)
			goto restart
		}

		slice := code[pointer.Position : pointer.Position+Address(pointer.Size)]

		switch pointer.Size {
		case 1:
			slice[0] = uint8(address)
		case 2:
			binary.LittleEndian.PutUint16(slice, uint16(address))
		case 4:
			binary.LittleEndian.PutUint32(slice, uint32(address))
		case 8:
			binary.LittleEndian.PutUint64(slice, uint64(address))
		}
	}

	// The header size depends on the executable format of the target; it
	// stays 0 for any other value of config.TargetOS.
	headerEnd := Address(0)

	switch config.TargetOS {
	case "linux":
		headerEnd = elf.HeaderEnd
	case "macos":
		headerEnd = macho.HeaderEnd
	case "windows":
		headerEnd = pe.HeaderEnd
	}

	codeStart, _ := fs.Align(headerEnd, config.Align)
	dataStart, _ := fs.Align(codeStart+Address(len(code)), config.Align)
	data, dataLabels = a.Data.Finalize()

	// Function pointers are absolute: base address + code start + label.
	for _, pointer := range funcPointers {
		address := config.BaseAddress + Address(codeStart) + pointer.Resolve()
		slice := code[pointer.Position : pointer.Position+4]
		binary.LittleEndian.PutUint32(slice, uint32(address))
	}

	// Data pointers are absolute: base address + data start + label.
	for _, pointer := range dataPointers {
		address := config.BaseAddress + Address(dataStart) + pointer.Resolve()
		slice := code[pointer.Position : pointer.Position+4]
		binary.LittleEndian.PutUint32(slice, uint32(address))
	}

	if config.TargetOS == "windows" {
		// The data section is never left empty on Windows.
		if len(data) == 0 {
			data = []byte{0}
		}

		importsStart, _ := fs.Align(dataStart+Address(len(data)), config.Align)

		// DLL calls are encoded relative to the end of their 4-byte operand.
		for _, pointer := range dllPointers {
			destination := Address(importsStart) + pointer.Resolve()
			delta := destination - Address(codeStart+pointer.Position+Address(pointer.Size))
			slice := code[pointer.Position : pointer.Position+4]
			binary.LittleEndian.PutUint32(slice, uint32(delta))
		}
	}

	return code, data
}