.globl add64
.globl mul
.globl div
.globl divmod
.globl clz

.text

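# Quoted from the RISC-V calling convention. This is why add64 below receives
# each 64-bit argument as a low/high register pair (a0/a1 and a2/a3):
#
# > When primitive arguments twice the size of a pointer-word are passed on the
# > stack, they are naturally aligned. When they are passed in the integer
# > registers, they reside in an aligned even-odd register pair, with the even
# > register holding the least-significant bits.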

# add64 -- add two 64-bit integers held in pairs of 32-bit registers
#   in:
#       a1:a0 = x  (a0 = low word,  a1 = high word)
#       a3:a2 = y  (a2 = low word,  a3 = high word)
#   out:
#       a1:a0 = x + y  (a0 = low word, a1 = high word); the carry out of the
#                      high word is discarded, so the sum wraps modulo 2^64
#   scratch:
#       t0, t1
#
add64:
    add  a0, a0, a2  # low word of the sum; wraps modulo 2^32 on overflow
    sltu t1, a0, a2  # carry = 1 exactly when the wrapped low sum is below y's low word
    add  t0, a1, a3  # sum of the high words, carry not yet applied
    add  a1, t0, t1  # high word of the result = high-word sum + carry
    ret

# mul -- 32-bit multiplication by shift-and-add
#    in:
#        a0 = multiplicand
#        a1 = multiplier
#    out:
#        a0 = low 32 bits of multiplicand * multiplier
#    scratch:
#        a1, t0, t1
#
# Walks the multiplier from its least-significant bit upwards. For every set
# bit the (progressively doubled) multiplicand is added to the running product.
# The loop stops as soon as no set bits remain in the multiplier.
#
mul:
    mv      t0, a1              # t0 = multiplier bits still to be examined
    li      a1, 0               # a1 = running product
    beqz    t0, .Lmul_done      # multiplier == 0: product stays 0

.Lmul_loop:
    andi    t1, t0, 1           # t1 = current lowest multiplier bit
    beqz    t1, .Lmul_next      # bit clear: nothing to add this round
    add     a1, a1, a0          # bit set: product += multiplicand

.Lmul_next:
    slli    a0, a0, 1           # multiplicand *= 2 for the next bit position
    srli    t0, t0, 1           # drop the bit that was just handled
    bnez    t0, .Lmul_loop      # keep going while set bits remain

.Lmul_done:
    mv      a0, a1              # hand the product back in a0
    ret

# 32-bit unsigned shift-subtract integer division
#    arguments:
#        a0: dividend, u
#        a1: divisor, v
#    return:
#        a0 = u / v, unsigned, truncated
#        a0 = -1 (all bits set) when v == 0; this is the value the RISC-V
#             DIVU instruction produces for division by zero
#    clobbers:
#        a1, a5, plus a3 and a4 through the calls to clz
#    stack:
#        16-byte frame: ra at 12(sp), s0 at 8(sp), s1 at 4(sp), clz(v) at 0(sp)
#
# https://blog.segger.com/algorithms-for-division-part-2-classics/
div:
    beqz    a1, .Ldiv_by_zero        # v == 0: the loop below would return garbage
    bltu    a0, a1, .Ldiv_zero_quot  # if (u < v) return 0

    addi    sp, sp, -16
    sw      ra, 12(sp)
    sw      s0, 8(sp)
    sw      s1, 4(sp)
    mv      s1, a0                   # s1 = u, kept across the clz calls
    mv      s0, a1                   # s0 = v, kept across the clz calls

    mv      a0, a1
    jal     clz                      # a0 = clz(v)
    sw      a0, 0(sp)                # park clz(v); clz overwrites a5
    mv      a0, s1
    jal     clz                      # a0 = clz(u)
    lw      a5, 0(sp)                # a5 = clz(v)
    sub     a5, a5, a0               # k = clz(v) - clz(u); number of quotient digits - 1
                                     # (k >= 0 because u >= v > 0 here)
    sll     a1, s0, a5               # v <<= k;             Normalize divisor
    li      a0, 0                    # q = 0;               Init quotient

    # Iterate k+1 times, each iteration developing one quotient bit.
.Ldiv_loop:
    slli    a0, a0, 1                # q <<= 1;             Record preliminary '0' quotient digit
    bltu    s1, a1, .Ldiv_next       # if (u >= v)          Subtraction will succeed...
    sub     s1, s1, a1               # u -= v;
    addi    a0, a0, 1                # q += 1;              Turn preliminary '0' quotient digit to '1'
.Ldiv_next:
    addi    a5, a5, -1               # k -= 1;
    srli    a1, a1, 1                # v >>= 1;
    bgez    a5, .Ldiv_loop           # while (k >= 0);

    lw      ra, 12(sp)
    lw      s0, 8(sp)
    lw      s1, 4(sp)
    addi    sp, sp, 16
    ret

    # Early exits: taken before the stack frame is created, so a plain ret is enough.
.Ldiv_zero_quot:
    li      a0, 0                    # u < v: quotient is 0
    ret

.Ldiv_by_zero:
    li      a0, -1                   # v == 0: all bits set, as DIVU does
    ret

# divmod -- 32-bit integer division returning quotient and remainder
#    in:
#        a0 = dividend, u
#        a1 = divisor, v
#    out:
#        a0 = quotient, as returned by div
#        a1 = remainder, computed as u - v * quotient
#    stack:
#        16-byte frame: ra at 12(sp), s0 at 8(sp), s1 at 4(sp), quotient at 0(sp)
#
divmod:
    # Strategy: quotient = div(u, v); remainder = u - mul(v, quotient).
    addi    sp, sp, -16
    sw      s1, 4(sp)
    sw      s0, 8(sp)
    sw      ra, 12(sp)
    mv      s1, a1      # s1 = divisor, must survive both calls
    mv      s0, a0      # s0 = dividend, must survive both calls
    jal     div         # a0 = quotient
    sw      a0, 0(sp)   # keep the quotient; mul overwrites a0 and a1
    mv      a1, a0      # mul multiplier   = quotient
    mv      a0, s1      # mul multiplicand = divisor
    jal     mul         # a0 = divisor * quotient
    sub     a1, s0, a0  # remainder = dividend - divisor * quotient
    lw      a0, 0(sp)   # a0 = quotient
    lw      s1, 4(sp)
    lw      s0, 8(sp)
    lw      ra, 12(sp)
    addi    sp, sp, 16
    ret

# clz -- count the leading zero bits of a 32-bit value
#    in:
#        a0 = value
#    out:
#        a0 = number of leading zero bits (32 when the value is 0)
#    scratch:
#        a3, a4, a5
#
# Binary search translated from the C code at
# https://blog.stephencleary.com/2010/10/implementing-gccs-builtin-functions.html
#
# a4 holds an upper bound on the answer and a5 the part of the value that is
# still being examined. Each step looks at the upper half of a5: if it has any
# bit set, the bound drops by that half's width and the search moves into the
# upper half; otherwise the search stays in the lower half.
clz:
        srli    a5, a0, 16         # a5 = upper 16 bits of the value
        li      a4, 16             # assume the upper half is non-zero: at most 16 zeros
        bnez    a5, .Lclz_8        # assumption holds, continue with the upper half
        li      a4, 32             # upper half is empty: at most 32 zeros
        mv      a5, a0             # continue with the original value (it fits in 16 bits)
.Lclz_8:
        srli    a3, a5, 8          # a3 = bits above the low 8
        beqz    a3, .Lclz_4
        mv      a5, a3             # something set up there: continue with those bits
        addi    a4, a4, -8         # and rule out 8 zeros
.Lclz_4:
        srli    a3, a5, 4          # a3 = bits above the low 4
        beqz    a3, .Lclz_2
        mv      a5, a3             # something set up there: continue with those bits
        addi    a4, a4, -4         # and rule out 4 zeros
.Lclz_2:
        srli    a3, a5, 2          # a3 = bits above the low 2
        beqz    a3, .Lclz_1
        mv      a5, a3             # something set up there: continue with those bits
        addi    a4, a4, -2         # and rule out 2 zeros
.Lclz_1:
        # Only two bits are left in a5 (0..3).
        sub     a0, a4, a5         # a5 is 0 or 1: answer is bound - a5
        srli    a3, a5, 1          # a3 = upper bit of the remaining pair
        beqz    a3, .Lclz_ret      # upper bit clear: the value computed above is final
        addi    a0, a4, -2         # a5 is 2 or 3: both remaining positions are ruled out
.Lclz_ret:
        ret