Kata Codewars - “Sequência Repetitiva”

Aug 17 2020

Estou tentando resolver o seguinte codewars kata .

Temos uma lista como
seq = [0, 1, 2, 2]

Teremos que escrever uma função que irá, primeiramente, adicionar elementos na lista usando a seguinte lógica.
se n = 3 e como seq[n]=2, a nova lista será seq = [0, 1, 2, 2, 3, 3]
se n = 4 e como seq[n]=3, a nova lista será ser seq = [0, 1, 2, 2, 3, 3, 4, 4, 4]
se n = 5 e como seq[n]=3, a nova lista será seq = [0, 1, 2, 2 , 3, 3, 4, 4, 4, 5, 5, 5] e assim por diante.

Então a função retornará o n-ésimo elemento da lista.

Alguns elementos da lista:
[0, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8 , 8, 8, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 13, 13 , 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 17, 17 , 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20 , 20, 20, 21, 21, 21, 21, 21, 21, 21, 21]

Restrição para Python:
0 <= n <= 2^41

Meu código é executado com êxito em meu sistema para qualquer valor de n, incluindo n=2^41 (dentro de 2,2s). Mas o tempo limite em Codewars. Alguém pode me ajudar a otimizar meu código? Desde já, obrigado.

Meu código:

def find(n):
    arr = [0, 1, 2, 2]
    if n <= 3:
        return arr[n]
    else:
        arr_sum = 5
        for i in range(3, n+1):
            arr_sum += i * arr[i]
            if arr_sum >= n:
                x = (arr_sum - n) // i
                return len(arr) + arr[i] - (x+1)
            else:
                arr += [i] * arr[i]

Respostas

2 superbrain Aug 18 2020 at 00:13

Construir a lista leva muito tempo. Observe que você arr += [i] * arr[i]repete o mesmo valor ivárias vezes. Isso pode ser muito compactado apenas armazenando repeat(i, arr[i]). Isso ocorre em cerca de 6 segundos, bem abaixo do limite de 12 segundos:

from itertools import chain, repeat

def find(n):
    if n <= 3:
        return [0, 1, 2, 2][n]
    arr = [[2]]
    arr_sum = 5
    arr_len = 4
    for i, arr_i in enumerate(chain.from_iterable(arr), 3):
        arr_sum += i * arr_i
        if arr_sum >= n:
            x = (arr_sum - n) // i
            return arr_len + arr_i - (x+1)
        arr.append(repeat(i, arr_i))
        arr_len += arr_i

Observe que, no n > 3caso, começamos já com o número 3, acrescentando-o duas vezes à lista. Portanto, da sequência start [0, 1, 2, 2], precisamos apenas de [2], então começo com arr = [[2]](que é mais curto que [repeat(2, 1)], e chainnão importa).


Como alternativa... observe que você está estendendo arrmuito mais rápido do que consumindo. Para n=2 41 , você aumenta para mais de 51 milhões de elementos, mas na verdade está lendo menos do que os primeiros 70 mil. Portanto, você pode parar de estender a lista nesse ponto. Isso acontece em cerca de 4,7 segundos:

def find(n):
    arr = [0, 1, 2, 2]
    if n <= 3:
        return arr[n]
    arr_sum = 5
    arr_len = 4
    for i in range(3, n+1):
        arr_sum += i * arr[i]
        if arr_sum >= n:
            x = (arr_sum - n) // i
            return arr_len + arr[i] - (x+1)
        arr_len += arr[i]
        if arr_len < 70_000:
            arr += [i] * arr[i]

E... você pode combinar as duas melhorias acima, ou seja, aplicar isso if arr_len < 70_000:ao repeat-version. Isso acontece em cerca de 4,5 segundos.

Resultados de benchmark no meu PC para n=2 41 :

Your original:   1.795 seconds
My first one:    0.043 seconds (42 times faster)
My second one:   0.041 seconds (44 times faster)
The combination: 0.026 seconds (69 times faster)

Ah, e um comentário de estilo: Você faz isso duas vezes:

if ...:
    return ...
else:
    ...

O elsee o recuo de todo o código restante são desnecessários e eu o evitaria. Eu fiz isso nas soluções acima.