"""IntegerPartitions.py

Generate and manipulate partitions of integers into sums of integers.

D. Eppstein, August 2005.
"""

import unittest

def mckay(n):
    """
    Generate the integer partitions of n in reverse lexicographic order.

    Each partition is a list of positive integers in nonincreasing order.
    One single list object is yielded over and over and is rewritten in
    place between outputs: the caller must not modify it, and must copy
    it (e.g. list(p)) to keep a partition past the next step of the
    generator.  For n == 0 the only output is the empty list; negative n
    produces no output at all.

    The method roughly follows the algorithm in Knuth v4 fasc3 p38.
    """
    if n < 0:
        return
    if n == 0:
        yield []
        return

    parts = [n]
    # Index of the rightmost part larger than 1, or -1 once all parts are 1.
    pivot = 0 if n > 1 else -1
    while True:
        yield parts
        if pivot < 0:
            return
        if parts[pivot] == 2:
            # Cheap case: the 2 becomes 1+1, so one extra unit is appended.
            parts[pivot] = 1
            parts.append(1)
            pivot -= 1
        else:
            smaller = parts[pivot] - 1
            # Amount to redistribute: the pivot part plus the ones after it.
            tail_total = smaller + len(parts) - pivot
            copies, leftover = divmod(tail_total, smaller)
            # Replace the tail by as many copies of the decremented part as
            # fit, followed by whatever remains.
            del parts[pivot:]
            parts.extend([smaller] * copies)
            if leftover:
                parts.append(leftover)
            pivot = len(parts) - 1
            if parts[-1] == 1:
                pivot -= 1

def revlex_partitions(n):
    """
    Generate the integer partitions of n in reverse lexicographic order.

    The outputs match those of mckay(n), but are produced recursively:
    every partition of n-1 is extended either by enlarging its last part
    or by appending a new part equal to 1.  No division is involved, at
    the price of a recursion that nests n generators deep.

    The yielded list is shared with the recursive generators and is
    modified in place, so callers must copy it to keep it and must
    leave it unchanged.
    """
    if n < 0:
        return
    if n == 0:
        yield []
        return

    for base in revlex_partitions(n - 1):
        size = len(base)
        # The last part may grow only if the result stays nonincreasing:
        # it is the sole part, or strictly smaller than its predecessor.
        can_grow = size == 1 or (size >= 2 and base[-1] < base[-2])
        if can_grow:
            base[-1] += 1
            yield base
            base[-1] -= 1       # restore before extending the other way
        base.append(1)
        yield base
        base.pop()              # hand the list back exactly as received

def lex_partitions(n):
    """
    Generate the integer partitions of n in lexicographic order.

    Works like revlex_partitions, except that for each partition of n-1
    the extension by a new part 1 is produced before the extension that
    enlarges the last part.  The yielded list is shared with the
    recursive generators and modified in place; copy it to keep it.
    """
    if n < 0:
        return
    if n == 0:
        yield []
        return

    for base in lex_partitions(n - 1):
        base.append(1)
        yield base
        base.pop()
        size = len(base)
        # Enlarging the last part is legal only when the parts remain
        # nonincreasing afterwards.
        can_grow = size == 1 or (size >= 2 and base[-1] < base[-2])
        if can_grow:
            base[-1] += 1
            yield base
            base[-1] -= 1       # hand the list back exactly as received

# Default partition-generating algorithm: code with no reason to prefer one
# particular generator (for instance the conjugation tests further down)
# calls partitions(n), which is simply another name for revlex_partitions.
partitions = revlex_partitions

def binary_partitions(n):
    """
    Generate partitions of n into powers of two, in revlex order.
    Knuth exercise 7.2.1.4.64.

    The same list object is yielded each time and is modified in place
    between outputs, so copy it if it has to outlive the next step.
    Negative n produces no output; n == 0 produces the empty partition.

    The average time per output is constant.
    But this doesn't really solve the exercise, because it isn't loopless...
    """
    if n < 0:
        return

    # Start from the binary representation of n: the set bits of n,
    # listed from the most significant one down.
    bit = 1
    while bit <= n:
        bit <<= 1
    parts = []
    remaining = n
    while bit:
        if bit <= remaining:
            parts.append(bit)
            remaining -= bit
        bit >>= 1

    # Find all partitions of numbers up to n into powers of two > 1,
    # in revlex order, by repeatedly splitting the smallest nonunit power,
    # and replacing the following sequence of 1's by the first revlex
    # partition with maximum power less than the result of the split.

    # Time analysis:
    #
    # Each outer iteration increases len(parts) by at most one
    # (only if the power being split is a 2) and each inner iteration
    # in which some ones are replaced by a larger power decreases
    # len(parts), so the number of those inner iterations is less than
    # one per output.
    #
    # Each time a power 2^k is split, it creates two or more 2^{k-1}'s,
    # all of which must eventually be split as well.  So, if S_k denotes
    # the number of times a 2^k is split, and X denotes the total
    # number of outputs generated, then S_k <= X/2^{k-1}.
    # On an outer iteration in which 2^k is split, there will be k
    # inner iterations in which the candidate power is halved, so the
    # total number of such inner iterations is
    # <= sum_k k*X/2^{k-1} = O(X).
    #
    # Therefore the overall average time per output is constant.

    # Index of the rightmost part larger than 1 (-1 when there is none);
    # an odd n ends with a single 1 that has to be skipped.
    pivot = len(parts) - 1 - (n & 1)
    while True:
        yield parts
        if pivot < 0:
            return

        if parts[pivot] == 2:
            # A 2 simply becomes two units.
            parts[pivot] = 1
            parts.append(1)
            pivot -= 1
            continue

        # Make the split: one power becomes two copies of its half.  The
        # second copy overwrites the slot after the pivot, and the unit
        # appended first keeps the overall sum unchanged.
        half = parts[pivot] >> 1
        parts.append(1)
        parts[pivot] = half
        parts[pivot + 1] = half
        pivot += 1

        # Greedily regroup the trailing units into the largest powers that
        # do not exceed the half just created.
        while half > 1:
            trailing_ones = len(parts) - pivot - 1
            if trailing_ones < half:
                half >>= 1
                continue
            # Merge `half` units into one part: drop half-1 of them and
            # overwrite the next one.
            del parts[1 - half:]
            pivot += 1
            parts[pivot] = half

def fixed_length_partitions(n,L):
    """
    Integer partitions of n into exactly L positive parts, in colex order.

    Each output is a list of L integers in nonincreasing order that sum
    to n.  The same list object is yielded every time and is modified in
    place between outputs, so the caller must leave it unchanged and must
    copy any partition that has to outlive the next step of the generator.

    Nothing is generated when no such partition exists: n < L, negative L,
    or L == 0 with n != 0.  For n == L == 0 the single output is [].

    The algorithm follows Knuth v4 fasc3 p38 in rough outline;
    Knuth credits it to Hindenburg, 1779.
    """

    # guard against special cases
    if L < 0:
        # A negative part count admits no partitions.  Without this guard
        # the seed list built below would degenerate (its (L-1)*[1] tail
        # is empty), giving a bogus output followed by an IndexError.
        return
    if L == 0:
        if n == 0:
            yield []
        return
    if L == 1:
        if n > 0:
            yield [n]
        return
    if n < L:
        return

    # First partition: everything in the leading part, all other parts 1.
    partition = [n - L + 1] + (L-1)*[1]
    while True:
        yield partition

        # Easy case: move one unit from the first part to the second while
        # the first part stays at least as large as the second.
        if partition[0] - 1 > partition[1]:
            partition[0] -= 1
            partition[1] += 1
            continue

        # Otherwise find the leftmost part (from index 2 on) that is smaller
        # than partition[0] - 1; s accumulates the sum of the parts before
        # it, less the one unit that is about to be moved into it.
        j = 2
        s = partition[0] + partition[1] - 1
        while j < L and partition[j] >= partition[0] - 1:
            s += partition[j]
            j += 1
        if j >= L:
            return              # no part can be increased: all done

        # Increase that part, level parts 1..j-1 to the same value, and
        # give whatever is left of s to the first part.
        partition[j] = x = partition[j] + 1
        j -= 1
        while j > 0:
            partition[j] = x
            s -= x
            j -= 1
        partition[0] = s

def conjugate(p):
    """
    Return the conjugate of the partition p as a new list.

    p is indexed from its end, so it is expected to list its parts in
    nonincreasing order, as the generators in this module do.  The k-th
    entry of the result (counting from 1) is the number of parts of p
    that are at least k; hence len(p) = max(conjugate(p)) and vice versa.
    An empty p gives an empty result.
    """
    columns = []
    alive = len(p)      # number of leading parts still larger than len(columns)
    while alive > 0:
        columns.append(alive)
        height = len(columns)
        # Drop the trailing parts that are exhausted at this height.
        while alive > 0 and p[alive - 1] <= height:
            alive -= 1
    return columns
    
# If run standalone, perform unit tests

class PartitionTest(unittest.TestCase):
    """Unit tests for the partition generators and for conjugate().

    The checks use assertTrue rather than the old assert_ alias, which
    was deprecated and has been removed from unittest in Python 3.12.
    """

    # counts[n] is the expected number of integer partitions of n.
    counts = [1,1,2,3,5,7,11,15,22,30,42,56,77,101,135]

    def testCounts(self):
        """Check that each generator has the right number of outputs."""
        for n in range(len(self.counts)):
            self.assertEqual(self.counts[n],len(list(mckay(n))))
            self.assertEqual(self.counts[n],len(list(lex_partitions(n))))
            self.assertEqual(self.counts[n],len(list(revlex_partitions(n))))

    def testSums(self):
        """Check that all outputs are partitions of the input."""
        for n in range(len(self.counts)):
            for p in mckay(n):
                self.assertEqual(n,sum(p))
            for p in revlex_partitions(n):
                self.assertEqual(n,sum(p))
            for p in lex_partitions(n):
                self.assertEqual(n,sum(p))

    def testRevLex(self):
        """Check that the revlex generators' outputs are in revlex order."""
        for n in range(len(self.counts)):
            # [n+1] compares greater than every partition of n.
            last = [n+1]
            for p in mckay(n):
                self.assertTrue(last > p)
                last = list(p)  # make less-mutable copy
            last = [n+1]
            for p in revlex_partitions(n):
                self.assertTrue(last > p)
                last = list(p)  # make less-mutable copy

    def testLex(self):
        """Check that the lex generator's outputs are in lex order."""
        # n == 0 is skipped: its only partition is [], which is not
        # strictly greater than the initial sentinel [].
        for n in range(1,len(self.counts)):
            last = []
            for p in lex_partitions(n):
                self.assertTrue(last < p)
                last = list(p)  # make less-mutable copy

    def testRange(self):
        """Check that all numbers in output partitions are in range."""
        for n in range(len(self.counts)):
            for p in mckay(n):
                for x in p:
                    self.assertTrue(0 < x <= n)
            for p in lex_partitions(n):
                for x in p:
                    self.assertTrue(0 < x <= n)
            for p in revlex_partitions(n):
                for x in p:
                    self.assertTrue(0 < x <= n)

    def testFixedLength(self):
        """Check that the fixed length partition outputs are correct."""
        for n in range(len(self.counts)):
            # Reference: all partitions of n, copied and sorted.
            pn = [list(p) for p in revlex_partitions(n)]
            pn.sort()
            np = 0
            for L in range(n+1):
                pnL = [list(p) for p in fixed_length_partitions(n,L)]
                pnL.sort()
                np += len(pnL)
                self.assertEqual(pnL,[p for p in pn if len(p) == L])
            # Every partition of n must appear for exactly one length L.
            self.assertEqual(np,len(pn))

    def testConjugatePartition(self):
        """Check that conjugating a partition forms another partition."""
        for n in range(len(self.counts)):
            for p in partitions(n):
                c = conjugate(p)
                for x in c:
                    self.assertTrue(0 < x <= n)
                self.assertEqual(sum(c),n)

    def testConjugateInvolution(self):
        """Check that double conjugation returns the same partition."""
        for n in range(len(self.counts)):
            for p in partitions(n):
                self.assertEqual(p,conjugate(conjugate(p)))

    def testConjugateMaxLen(self):
        """Check the max-length reversing property of conjugation."""
        # n == 0 is skipped: max() of an empty conjugate would raise.
        for n in range(1,len(self.counts)):
            for p in partitions(n):
                self.assertEqual(len(p),max(conjugate(p)))

    def testBinary(self):
        """Test that the binary partitions are generated correctly."""
        for n in range(len(self.counts)):
            # Reference: the partitions of n whose parts are all powers
            # of two, in the order the default generator produces them.
            binaries = []
            for p in partitions(n):
                for x in p:
                    if x & (x - 1):     # nonzero iff x is not a power of two
                        break
                else:
                    binaries.append(list(p))
            self.assertEqual(binaries,[list(p) for p in binary_partitions(n)])

# Run the unittest test runner only when this file is executed as a script;
# importing the module runs no tests.
if __name__ == "__main__":
    unittest.main()
