R velocidade de data.table

Aug 19 2020

Eu tenho um problema de desempenho específico, que desejo estender de forma mais geral, se possível.

Contexto:

Eu tenho brincado no google colab com um exemplo de código python para um agente Q-Learning, que associa um estado e uma ação a um valor usando um defaultdict:

self._qvalues = defaultdict(lambda: defaultdict(lambda: 0))
return self._qvalues[state][action]

Não é um especialista, mas pelo que entendi, ele retorna o valor ou adiciona e retorna 0 se a chave não for encontrada.

estou adaptando parte disso em R.
o problema é que não sei quantas combinações de estado/valores tenho e, tecnicamente, não devo saber quantos estados, acho.
No começo eu fui pelo caminho errado, com o rbindde data.frames e isso foi muito lento.
Em seguida, substituí meu objeto R por um arquivo data.frame(state, action, value = NA_real). funciona, mas ainda é muito lento. outro problema é que meu objeto data.frame tem o tamanho máximo que pode ser problemático no futuro.
então eu mudei data.framepara a data.table, que me deu o pior desempenho, então eu finalmente o indexei por (estado, ação).

qvalues <- data.table(qstate = rep(seq(nbstates), each = nbactions),
                        qaction = rep(seq(nbactions), times = nbstates),
                        qvalue = NA_real_,
                        stringsAsFactors =  FALSE)
setkey(qvalues, "qstate", "qaction")

Problema:

Comparando googlecolab/python com minha implementação R local, o google realiza acesso 1000x10e4 ao objeto em, digamos, 15s, enquanto meu código realiza acesso 100x100 em 28s. Eu obtive melhorias de 2s por compilação de bytes, mas isso ainda é muito ruim.

Usando profvis, vejo que a maior parte do tempo é gasto acessando o data.table nessas duas chamadas:

qval <- self$qvalues[J(state, action), nomatch = NA_real_]$qvalue
self$qvalues[J(state, action)]$qvalue <- value

Eu realmente não sei o que o Google tem, mas minha área de trabalho é uma fera. Também vi alguns benchmarks afirmando que data.tableera mais rápido que pandas, então suponho que o problema esteja na minha escolha do contêiner.

Perguntas:

  1. meu uso de um data.table está errado e pode ser corrigido para melhorar e corresponder à implementação do python?
  2. outro design é possível para evitar declarar todas as combinações de estado/ações que podem ser um problema se as dimensões se tornarem muito grandes?
  3. eu vi sobre o pacote de hash, é o caminho a percorrer?

Muito obrigado por qualquer indicação!

ATUALIZAR:

obrigado por toda a entrada. Então, o que fiz foi substituir 3 acessos ao meu data.table usando suas sugestões:

#self$qvalues[J(state, action)]$qvalue <- value
self$qvalues[J(state, action), qvalue := value]
#self$qvalues[J(state, action),]$qvalue <- 0
self$qvalues[J(state, action), qvalue := 0]
#qval <- self$qvalues[J(state, action), nomatch = NA_real_]$qvalue
qval <- self$qvalues[J(state, action), nomatch = NA_real_, qvalue]

isso reduziu o tempo de execução de 33s para 21s, o que é uma grande melhoria, mas ainda é extremamente lento em comparação com a defaultdictimplementação do python.

Observei o seguinte:
trabalhando em lote: acho que não consigo, pois a chamada da função depende da chamada anterior.
peudospin> Vejo que você está surpreso que a obtenção seja demorada. eu também, mas é o que o profvis afirma:

e aqui o código da função como referência:

QAgent$set("public", "get_qvalue", function( state, action) {
  #qval <- self$qvalues[J(state, action), nomatch = NA_real_]$qvalue
  qval <- self$qvalues[J(state, action), nomatch = NA_real_, qvalue]
  if (is.na(qval)) {
    #self$qvalues[self$qvalues$qstate == state & self$qvalues$qaction == action,]$qvalue <- 0
    #self$qvalues[J(state, action),]$qvalue <- 0
    self$qvalues[J(state, action), qvalue := 0]
    return(0)
  }
  return(qval)
})

Neste ponto, se não houver mais sugestões, concluirei que o data.table é muito lento para esse tipo de tarefa e devo procurar usar um envou um collections. (conforme sugerido lá: R pesquisa rápida de item único da lista vs data.table vs hash )

CONCLUSÃO:

Troquei o data.tablepor a collections::dicte o gargalo desapareceu completamente.

Respostas

2 pseudospin Aug 19 2020 at 02:45

data.tableé rápido para fazer pesquisas e manipulações em tabelas de dados muito grandes, mas não será rápido para adicionar linhas uma a uma como os dicionários python. Eu esperaria que copiasse toda a tabela cada vez que você adicionasse uma linha que claramente não é o que você deseja.

Você pode tentar usar ambientes (que são algo como um hashmap) ou, se realmente quiser fazer isso em R, pode precisar de um pacote especializado, aqui está um link para uma resposta com algumas opções.

AbdessabourMtk Aug 19 2020 at 04:53

referência

library(data.table)
Sys.setenv('R_MAX_VSIZE'=32000000000) # add to the ram limit
setDTthreads(threads=0) # use maximum threads possible
nbstates <- 1e3
nbactions <- 1e5

cartesian <- function(nbstates,nbactions){
    x= data.table(qstate=1:nbactions)
    y= data.table(qaction=1:nbstates)
    k = NULL
    x = x[, c(k=1, .SD)]
    setkey(x, k)
    y = y[, c(k=1, .SD)]
    setkey(y, NULL)
    x[y, allow.cartesian=TRUE][, c("k", "qvalue") := list(NULL, NA_real_)][]
}

#comparing seq with `:`
(bench = microbenchmark::microbenchmark(
    1:1e9,
    seq(1e9),
    times=1000L
    ))
#> Unit: nanoseconds
#>        expr  min   lq     mean median   uq   max neval
#>     1:1e+09  120  143  176.264  156.0  201  5097  1000
#>  seq(1e+09) 3039 3165 3333.339 3242.5 3371 21648  1000
ggplot2::autoplot(bench)


(bench = microbenchmark::microbenchmark(
    "Cartesian product" = cartesian(nbstates,nbactions),
    "data.table assignement"=qvalues <- data.table(qstate = rep(seq(nbstates), each = nbactions),
                        qaction = rep(seq(nbactions), times = nbstates),
                        qvalue = NA_real_,
                        stringsAsFactors =  FALSE),
    times=100L))
#> Unit: seconds
#>                    expr      min       lq     mean   median       uq      max neval
#>       Cartesian product 3.181805 3.690535 4.093756 3.992223 4.306766 7.662306  100
#>  data.table assignement 5.207858 5.554164 5.965930 5.895183 6.279175 7.670521  100
#>      data.table  (1:nb) 5.006773 5.609738 5.828659  5.80034 5.979303 6.727074  100
#>  
#>  
ggplot2::autoplot(bench)

é claro que usar seqconsome mais tempo do que chamar o 1:nb. além disso, usar um produto cartesiano torna o código mais rápido mesmo quando 1:nbé usado