StableMath.py

All functions of the `StableMath` class are static methods, so they can be called without instantiating the class.

calculate_invariant

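Computes the pool invariant for the current balances by Newton-Raphson iteration. The estimate is refined at most 255 times, and the loop stops as soon as two successive estimates differ by at most 1.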

def calculate_invariant(amplificationParameter: Decimal, balances: list) -> Decimal:
    """Compute the StableSwap invariant D of ``balances`` by Newton-Raphson iteration.

    Each step first rebuilds the product term

        P_D = n*b0, then P_D = ceil(P_D * b_j * n / D) for every further balance b_j

    and then updates the estimate

        D = ceil((n*D*D + A*n*S*P_D) / ((n + 1)*D + (A*n - 1)*P_D))

    where ``A`` is ``amplificationParameter``, ``n`` the number of tokens and
    ``S`` the sum of the balances. The iteration starts from ``D = S``. It stops
    after at most 255 steps, or as soon as two successive estimates differ by at
    most 1.

    Args:
        amplificationParameter: amplification coefficient A of the pool.
        balances: current token balances.

    Returns:
        The invariant as a ``Decimal``. An empty or all-zero pool yields
        ``Decimal(0)``. If the loop has not converged after 255 steps, the last
        estimate is returned (no error is raised).
    """
    bal_sum = sum(balances)
    if bal_sum == 0:
        # Return a Decimal (not the int 0) so the annotated return type holds.
        return Decimal(0)

    num_tokens = len(balances)
    amp_times_total = amplificationParameter * num_tokens
    invariant = bal_sum
    for _ in range(255):
        p_d = num_tokens * balances[0]
        for balance in balances[1:]:
            p_d = ceil(p_d * balance * num_tokens / invariant)
        prev_invariant = invariant
        invariant = ceil(
            (num_tokens * invariant * invariant + amp_times_total * bal_sum * p_d)
            / ((num_tokens + 1) * invariant + (amp_times_total - 1) * p_d)
        )
        # Converged once successive estimates are within 1 of each other.
        if abs(invariant - prev_invariant) <= 1:
            break
    return Decimal(invariant)

calc_out_given_in

Computes how many tokens can be taken out of the pool when `tokenAmountIn` tokens are sent in, given the current balances.

Notation: A — amount of tokens (in and out); B — token balances in the pool (for "token in" and "token out"). Unlike the weighted-pool formulas, this function takes no token weights (W); the shape of the curve is controlled by `amplificationParameter` instead.

def calcOutGivenIn(amplificationParameter: Decimal, balances: list, tokenIndexIn: int, tokenIndexOut: int, tokenAmountIn: Decimal):
    """Return how many ``tokenIndexOut`` tokens a swap of ``tokenAmountIn`` yields.

    The invariant is computed for the current ``balances``. Then the balance of
    ``tokenIndexOut`` that keeps that invariant, after ``tokenAmountIn`` is added
    to ``tokenIndexIn``, is solved for. The amount out is the current
    out-balance minus that solved balance.

    Args:
        amplificationParameter: amplification coefficient of the pool.
        balances: current token balances. Not modified.
        tokenIndexIn: index of the token sent in.
        tokenIndexOut: index of the token taken out.
        tokenAmountIn: amount of ``tokenIndexIn`` sent in.

    Returns:
        The amount of ``tokenIndexOut`` paid out.
    """
    invariant = StableMath.calculate_invariant(amplificationParameter, balances)

    # Build the post-swap balances on a copy. Adding and then subtracting on the
    # caller's list is not an exact round trip under Decimal context precision,
    # and an exception in the solver would leave the list mutated.
    balancesAfterIn = list(balances)
    balancesAfterIn[tokenIndexIn] = balancesAfterIn[tokenIndexIn] + tokenAmountIn

    finalBalanceOut = StableMath.get_token_balance_given_invariant_and_all_other_balances(
        amplificationParameter, balancesAfterIn, invariant, tokenIndexOut
    )
    return balances[tokenIndexOut] - finalBalanceOut

calc_in_given_out

Computes how many tokens "in" must be sent to the pool to receive the desired amount of tokens "out".

def calc_in_given_out(amplificationParameter: Decimal, balances: list, tokenIndexIn: int, tokenIndexOut: int, tokenAmountOut: Decimal):
    """Return how many ``tokenIndexIn`` tokens must be sent in to receive ``tokenAmountOut``.

    The invariant is computed for the current ``balances``. Then the balance of
    ``tokenIndexIn`` that keeps that invariant, after ``tokenAmountOut`` is
    removed from ``tokenIndexOut``, is solved for. The amount in is that solved
    balance minus the current in-balance, plus one 1e-18 unit so the result is
    rounded up.

    Args:
        amplificationParameter: amplification coefficient of the pool.
        balances: current token balances. Not modified.
        tokenIndexIn: index of the token sent in (used as a list index).
        tokenIndexOut: index of the token taken out (used as a list index).
        tokenAmountOut: desired amount of ``tokenIndexOut``.

    Returns:
        The amount of ``tokenIndexIn`` required.
    """
    invariant = StableMath.calculate_invariant(amplificationParameter, balances)

    # Work on a copy rather than mutating and restoring the caller's list.
    balancesAfterOut = list(balances)
    balancesAfterOut[tokenIndexOut] = balancesAfterOut[tokenIndexOut] - tokenAmountOut

    finalBalanceIn = StableMath.get_token_balance_given_invariant_and_all_other_balances(
        amplificationParameter,
        balancesAfterOut,
        invariant,
        tokenIndexIn,
    )

    # Decimal("1e-18") is exact; Decimal(1/1e18) would carry the binary float's
    # representation error (1.0000000000000000715...E-18).
    return finalBalanceIn - balances[tokenIndexIn] + Decimal("1e-18")

calc_bpt_in_given_exact_tokens_out

def calc_bpt_in_given_exact_tokens_out(amplificationParameter: Decimal, balances: list, amountsOut: list, bptTotalSupply: Decimal, swapFee: Decimal) -> Decimal:
    """Return how much BPT must be burned to withdraw exactly ``amountsOut``.

    Intended to mirror Balancer V2's on-chain ``_calcBptInGivenExactTokensOut``:

    1. Each token's weight is its share of the summed balances. Its balance ratio
       after the withdrawal is ``(balance - amountOut) / balance``.
    2. A token whose ratio is below the weighted average ratio is withdrawn
       disproportionately. The relative excess is charged ``swapFee`` by grossing
       the amount out up to ``amountOut / complement(swapFee * excess)``.
    3. The BPT in is ``bptTotalSupply * complement(newInvariant / currentInvariant)``.

    Args:
        amplificationParameter: amplification coefficient of the pool.
        balances: current token balances.
        amountsOut: exact amount of each token to withdraw, indexed like ``balances``.
        bptTotalSupply: current BPT supply.
        swapFee: swap fee percentage. Presumably a fraction below 1
            (e.g. ``Decimal("0.003")``) — confirm against callers.

    Returns:
        The amount of BPT to burn, as a ``Decimal``.
    """
    from decimal import localcontext

    # Compute with 28 significant digits without overwriting the caller's global
    # Decimal context (this used to run getcontext().prec = 28).
    with localcontext() as ctx:
        ctx.prec = 28

        currentInvariant = StableMath.calculate_invariant(amplificationParameter, balances)
        sumBalances = sum(balances, Decimal(0))

        # Step 1: per-token ratios (b - a) / b and their balance-weighted average.
        tokenBalanceRatiosWithoutFee = []
        weightedBalanceRatio = Decimal(0)
        for i, balance in enumerate(balances):
            currentWeight = divUp(balance, sumBalances)
            ratio = Decimal(divUp(balance - amountsOut[i], balance))
            tokenBalanceRatiosWithoutFee.append(ratio)
            weightedBalanceRatio = weightedBalanceRatio + mulUp(ratio, currentWeight)

        # Step 2: new balances, with the fee charged on the disproportionate part.
        # Fees are usually charged on the token in; there is none here, so the
        # fee is applied to the token out.
        newBalances = []
        for i, balance in enumerate(balances):
            ratio = tokenBalanceRatiosWithoutFee[i]
            if weightedBalanceRatio <= ratio or ratio >= 1:
                # No excess. ratio >= 1 means nothing is withdrawn from this token.
                # Assuming complement(x) is max(0, 1 - x), complement(ratio) would
                # be 0 there and the division below would fail.
                tokenBalancePercentageExcess = Decimal(0)
            else:
                tokenBalancePercentageExcess = Decimal(
                    divUp(weightedBalanceRatio - ratio, complement(ratio))
                )
            swapFeeExcess = mulUp(swapFee, tokenBalancePercentageExcess)
            amountOutBeforeFee = Decimal(divUp(amountsOut[i], complement(swapFeeExcess)))
            newBalances.append(balance - amountOutBeforeFee)

        # Step 3: new invariant, taking swap fees into account.
        newInvariant = StableMath.calculate_invariant(amplificationParameter, newBalances)
        invariantRatio = divDown(newInvariant, currentInvariant)

        # amountBPTIn = supply * (1 - newInvariant / currentInvariant)
        return Decimal(mulUp(bptTotalSupply, complement(invariantRatio)))

calc_bpt_out_given_exact_tokens_in

def calc_bpt_out_given_exact_tokens_in(amplificationParameter: Decimal, balances: list, amountsIn: list, bptTotalSupply: Decimal, swapFee: Decimal, swapFeePercentage: Decimal) -> Decimal:
    """Return how much BPT is minted for depositing exactly ``amountsIn``.

    Intended to mirror Balancer V2's on-chain ``_calcBptOutGivenExactTokensIn``:

    1. Each token's weight is its share of the summed balances. Its balance ratio
       after the deposit is ``(balance + amountIn) / balance``.
    2. A token whose ratio exceeds the weighted average ratio is deposited
       disproportionately. The relative excess ``(ratio - avg) / (ratio - 1)``
       is charged ``swapFeePercentage`` by scaling the amount in by
       ``complement(swapFeePercentage * excess)``.
    3. The BPT out is ``bptTotalSupply * (newInvariant / currentInvariant - 1)``.

    Args:
        amplificationParameter: amplification coefficient of the pool.
        balances: current token balances (indexed by position; a list, not a dict).
        amountsIn: exact amount of each token deposited, indexed like ``balances``.
        bptTotalSupply: current BPT supply.
        swapFee: unused; kept so existing callers keep working.
        swapFeePercentage: fee applied to the disproportionate part of a deposit.

    Returns:
        The amount of BPT minted, as a ``Decimal``.
    """
    currentInvariant = StableMath.calculate_invariant(amplificationParameter, balances)
    sumBalances = sum(balances, Decimal(0))

    # Step 1: per-token ratios (b + a) / b and their balance-weighted average.
    tokenBalanceRatiosWithoutFee = []
    weightedBalanceRatio = Decimal(0)
    for i, balance in enumerate(balances):
        currentWeight = divDown(Decimal(balance), Decimal(sumBalances))
        ratio = divDown(Decimal(balance) + Decimal(amountsIn[i]), Decimal(balance))
        tokenBalanceRatiosWithoutFee.append(ratio)
        weightedBalanceRatio = weightedBalanceRatio + mulDown(ratio, currentWeight)

    # Step 2: new balances, with the fee charged on the disproportionate part.
    newBalances = []
    for i, balance in enumerate(balances):
        ratio = tokenBalanceRatiosWithoutFee[i]
        if weightedBalanceRatio >= ratio or ratio <= 1:
            # No excess. ratio <= 1 means nothing is deposited in this token, and
            # ratio - 1 would be 0 in the division below.
            tokenBalancePercentageExcess = Decimal(0)
        else:
            tokenBalancePercentageExcess = divUp(ratio - weightedBalanceRatio, ratio - 1)

        swapFeeExcess = mulUp(Decimal(swapFeePercentage), tokenBalancePercentageExcess)
        amountInAfterFee = mulDown(Decimal(amountsIn[i]), Decimal(complement(swapFeeExcess)))
        newBalances.append(balance + amountInAfterFee)

    # Step 3: BPT out is proportional to the invariant *growth*, not its ratio.
    newInvariant = Decimal(StableMath.calculate_invariant(amplificationParameter, newBalances))
    return Decimal(mulDown(bptTotalSupply, divDown(newInvariant, currentInvariant) - 1))

calc_tokens_in_given_exact_bpt_out

def calc_tokens_in_given_exact_bpt_out(amplificationParameter: Decimal, balances: list, tokenIndex: int, bptAmountOut: Decimal, bptTotalSupply: Decimal, swapFeePercentage: Decimal):
    """Return how many ``tokenIndex`` tokens must be deposited to mint exactly ``bptAmountOut``.

    The invariant must grow in proportion to the BPT supply:
    ``newInvariant = (bptTotalSupply + bptAmountOut) / bptTotalSupply * currentInvariant``.
    The balance of ``tokenIndex`` that reaches that invariant is solved for, with
    all other balances fixed. A single-token join is disproportionate by
    ``complement(weight)`` of the token, so that part is charged
    ``swapFeePercentage``. The result is grossed up by
    ``complement(swapFeePercentage * complement(weight))``.

    Args:
        amplificationParameter: amplification coefficient of the pool.
        balances: current token balances.
        tokenIndex: index of the token being deposited.
        bptAmountOut: exact BPT to mint.
        bptTotalSupply: current BPT supply.
        swapFeePercentage: swap fee percentage applied to the excess.

    Returns:
        The amount of ``tokenIndex`` to deposit.
    """
    currentInvariant = Decimal(StableMath.calculate_invariant(amplificationParameter, balances))

    # Scale the invariant by the BPT supply growth ratio.
    newInvariant = mulUp(divUp(bptTotalSupply + bptAmountOut, bptTotalSupply), currentInvariant)

    # Sum of the balances (not of their indices), used for the token's weight.
    sumBalances = sum(balances, Decimal(0))

    newBalanceTokenIndex = StableMath.get_token_balance_given_invariant_and_all_other_balances(
        amplificationParameter, balances, newInvariant, tokenIndex
    )
    amountInAfterFee = newBalanceTokenIndex - balances[tokenIndex]

    currentWeight = divDown(balances[tokenIndex], sumBalances)
    tokenBalancePercentageExcess = complement(currentWeight)
    swapFeeExcess = mulUp(swapFeePercentage, tokenBalancePercentageExcess)

    # Gross the fee-free amount up so that, after the fee, amountInAfterFee remains.
    return divUp(amountInAfterFee, complement(swapFeeExcess))

calc_tokens_out_given_exact_bpt_in

def calc_tokens_out_given_exact_bpt_in(balances: list, bptAmountIn: Decimal, bptTotalSupply: Decimal) -> list:
    """Return the token amounts paid out for burning ``bptAmountIn`` BPT.

    This is a proportional exit. Every balance is scaled by the share of the
    supply being burned, ``bptAmountIn / bptTotalSupply``, using the
    ``divDown``/``mulDown`` helpers.

    Args:
        balances: current token balances.
        bptAmountIn: BPT burned.
        bptTotalSupply: current BPT supply.

    Returns:
        A list of amounts out, indexed like ``balances`` (the old ``-> dict``
        annotation did not match the list actually returned).
    """
    bptRatio = divDown(bptAmountIn, bptTotalSupply)
    return [mulDown(balance, bptRatio) for balance in balances]

get_token_balance_given_invariant_and_all_other_balances

def get_token_balance_given_invariant_and_all_other_balances(amplificationParameter: Decimal, balances: list, invariant: Decimal, tokenIndex: int) -> Decimal:
    """Solve for the balance of ``tokenIndex`` that keeps the pool at ``invariant``.

    All other balances are held fixed. ``balances[tokenIndex]`` only serves to
    cancel its own factor out of the product term ``P_D``. With

        c = D*D / (A*n) * balances[tokenIndex] / P_D
        b = (sum of the other balances) + D / (A*n)

    the balance y solves ``y*y + (b - D)*y = c``. It is found by the Newton-Raphson
    update ``y <- (y*y + c) / (2*y + b - D)``, starting from
    ``(D*D + c) / (D + b)``. The loop runs at most 255 steps and stops once
    successive estimates differ by at most 1e-18. If that tolerance is never
    reached, the last estimate is returned.

    Args:
        amplificationParameter: amplification coefficient A of the pool.
        balances: token balances. Not modified.
        invariant: target invariant D.
        tokenIndex: index of the balance to solve for.

    Returns:
        The balance of ``tokenIndex`` consistent with ``invariant``.
    """
    num_tokens = len(balances)
    amp_times_total = amplificationParameter * num_tokens
    # Exact Decimal tolerance; the float 1/1e18 carries representation error.
    tolerance = Decimal("1e-18")

    # P_D = n*b0, then P_D = P_D * b_j * n / D for every further balance.
    p_d = num_tokens * balances[0]
    for balance in balances[1:]:
        p_d = (p_d * balance * num_tokens) / invariant

    # Sum of every balance except the one being solved for.
    bal_sum = Decimal(sum(balances)) - balances[tokenIndex]

    c = invariant * invariant / amp_times_total
    # Multiplying by balances[tokenIndex] cancels that token's factor inside P_D.
    c = divUp(mulUp(c, balances[tokenIndex]), p_d)
    b = bal_sum + divDown(invariant, amp_times_total)

    token_balance = divUp(invariant * invariant + c, invariant + b)
    for _ in range(255):
        prev_token_balance = token_balance
        token_balance = divUp(
            mulUp(token_balance, token_balance) + c,
            token_balance * Decimal(2) + b - invariant,
        )
        if abs(token_balance - prev_token_balance) <= tolerance:
            break
    return token_balance

Last updated