R: ma trận với các mũi tên chỉ hướng

Jan 03 2021

Tôi đang cố gắng tái tạo với R một thuật toán được mô tả trong Sutton và Barto (2018), nhưng tôi không thể tạo ra ma trận có các mũi tên như thuật toán được các tác giả mô tả trên trang 65:

Tôi đã cố gắng sử dụng gói "trường" cho mục đích này, nhưng không thành công nhiều.

Trong Python, giải pháp do Shangtong Zhang và Kenta Shimada đề xuất dựa vào việc sử dụng các ký hiệu mũi tên: ACTIONS_FIGS = ['←', '↑', '→', '↓'] nhưng điều này không hoạt động tốt với R ...

CHỈNH SỬA: Tôi đã viết mã các hành động ban đầu và hành động cập nhật bằng số như sau:

library(data.table)
action_random = data.table(cell=c(1:25))
action_random$action_up = action_random$action_right = action_random$action_down = action_random$action_left = rep(1,25)
action_random$proba = rep(1/4,25)
action_random

Tôi cũng có thể điều chỉnh mã được đăng ở đây , để vẽ một lưới đơn giản với các mũi tên đơn giản:

arrows = matrix(c("\U2190","\U2191","\U2192","\U2193"),nrow=2,ncol=2)
grid_arrows = expand.grid(x=1:ncol(arrows),y=1:nrow(arrows))
grid_arrows$val = arrows[as.matrix(grid_arrows[c('y','x')])]

library(ggplot2)

ggplot(grid_arrows, aes(x=x, y=y, label=val)) + 
  geom_tile(fill='transparent', colour = 'black') + 
  geom_text(size = 14) + 
  scale_y_reverse() +
  theme_classic() + 
  theme(axis.text  = element_blank(),
        panel.grid = element_blank(),
        axis.line  = element_blank(),
        axis.ticks = element_blank(),
        axis.title = element_blank())

Tuy nhiên:
(i) Không có sẵn unicode cho 2 mũi tên 4 hướng được báo cáo trong Bảng$\pi_\ast$ở trên
(ii) ... và vì vậy tôi đã không cố gắng viết mã phân tách giữa các giá trị số trong Bảng "action_random" và một Bảng đẹp có các mũi tên trong đó ...

Mọi gợi ý giúp giải quyết các vấn đề (i) và (ii) đều được hoan nghênh.

Trả lời

1 AbdurRohman Jan 21 2021 at 22:01

Đây là cách lưới + lưới để tái tạo ma trận:

library(grid)
library(lattice)

grid.newpage()
pushViewport(viewport(width = 0.8, height = 0.8)) 
grid.rect(width = 1, height = 1)
panel.grid(h = 4, v = 4)

direct = function(xCenter, yCenter, type){
  
  d= 0.05
  
  north = function(xCenter, yCenter){ 
    grid.curve(xCenter, yCenter-d ,xCenter, yCenter+d, 
               ncp = 1, angle = 90, gp=gpar(lwd=1, fill="black"),
               inflect = FALSE, shape = 0,
               arrow = arrow(type="closed", ends = "last", 
                      angle = 30, length = unit(0.2, "cm")))}
  
  west = function(xCenter, yCenter){
    grid.curve(xCenter+d, yCenter ,xCenter-d, yCenter, 
               ncp = 1, angle = 90, gp=gpar(lwd=1, fill="black"),
               inflect = FALSE, shape = 0,
               arrow = arrow(type="closed", ends = "last", 
                             angle = 30, length = unit(0.2, "cm")))}
  east = function(xCenter, yCenter){
    grid.curve(xCenter+d, yCenter ,xCenter-d, yCenter, 
               ncp = 1, angle = 90, gp=gpar(lwd=1, fill="black"),
               inflect = FALSE, shape = 0,
               arrow = arrow(type="closed", ends = "first", 
                             angle = 30, length = unit(0.2, "cm")))}
  
  northeast = function(xCenter, yCenter){
       grid.curve(xCenter-d, yCenter+d ,xCenter+d, yCenter-d, 
                 ncp = 1, angle = 90, gp=gpar(lwd=1, fill="black"),
                 inflect = FALSE, shape = 0,
                 arrow = arrow(type="closed", ends = "both", 
                         angle = 30, length = unit(0.2, "cm")))}
  
  northwest = function(xCenter, yCenter){
       grid.curve(xCenter-d, yCenter-d ,xCenter+d, yCenter+d, 
               ncp = 1, angle = 90, gp=gpar(lwd=1, fill="black"),
               inflect = FALSE, shape = 0,
               arrow = arrow(type="closed", ends = "both", 
                             angle = 30, length = unit(0.2, "cm")))}
  all = function(xCenter, yCenter){
      grid.curve(xCenter+d, yCenter ,xCenter-d, yCenter, 
                 ncp = 1, angle = 90, gp=gpar(lwd=1, fill="black"),
                 inflect = FALSE, shape = 0,
                 arrow = arrow(type="closed", ends = "both", 
                               angle = 30, length = unit(0.2, "cm")))
      grid.curve(xCenter, yCenter-d ,xCenter, yCenter+d, 
             ncp = 1, angle = 90, gp=gpar(lwd=1, fill="black"),
             inflect = FALSE, shape = 0,
             arrow = arrow(type="closed", ends = "both", 
                           angle = 30, length = unit(0.2, "cm")))}
  switch(type,
         'n' = north(xCenter, yCenter),
         'e' = east(xCenter, yCenter),
         'w' = west(xCenter, yCenter),
         'nw'= northwest(xCenter, yCenter),
         'ne' = northeast(xCenter, yCenter),
         'all' = all(xCenter, yCenter)
         )
}

x = seq(0.1, 0.9, by = 0.2)
y = x
centers = expand.grid(x0 = x, y0 = y)

row1 = row2 = row3 = c('ne','n', rep('nw',3))
row4 = c('ne','n','nw','w','w')
row5 = c('e','all','w','all','w')

dir = c(row1,row2,row3,row4,row5)
df = data.frame(centers, dir)

for (k in 1:nrow(df)) direct(df$x0[k], df$y0[k], df$dir[k])
grid.text(bquote(~pi["*"]), y = -0.05)

5 ViviG Jan 14 2021 at 21:33

Việc sử dụng gói emojifontgiúp tôi có được nhiều tùy chọn unicode hơn. Trong ggplot của bạn, bạn thêm family='EmojiOne'. Đây là một ví dụ sử dụng unicode

Thông tin thêm về gói biểu tượng cảm xúc tại đây

CHỈNH SỬA : Hack cho mũi tên 4 hướng:

Không phải là giải pháp đẹp nhất hoặc thanh lịch hơn, nhưng bạn có thể phủ các ggplots bằng cách sử dụng gói magickđể có được các mũi tên định hướng. Tạo hai lớp cốt truyện, một lớp có mũi tên trái-phải ( U+2194) và một lớp khác có mũi tên lên xuống ( U+2195), sau đó hợp nhất (cảm ơn @ Billy34 vì đã làm cho mã đẹp hơn một chút):

library(data.table)
library(magick)

library(ggplot2)
library(emojifont)

#layer 1
arrows1 = matrix(c("\U21B4","\U2195","\U2192","\U2193"),nrow=2,ncol=2)
grid_arrows1 = expand.grid(x=1:ncol(arrows1),y=1:nrow(arrows1))
grid_arrows1$val = arrows1[as.matrix(grid_arrows1[c('y','x')])] #layer 2 arrows2 = matrix(c("\U21B4","\U2194","\U2192","\U2193"),nrow=2,ncol=2) grid_arrows2 = expand.grid(x1=1:ncol(arrows2),y1=1:nrow(arrows2)) grid_arrows2$val = arrows2[as.matrix(grid_arrows2[c('y1','x1')])]

ggplot(grid_arrows1, aes(x=x, y=y, label=val),family='EmojiOne') + 
  geom_tile(fill='NA', colour = 'black') + 
  
  geom_text(size = 18) + 
  
  geom_text(grid_arrows2,mapping =  aes(x=x1, y=y1, label=val),size = 18) +
  scale_y_reverse() +
  theme_classic() + 
  theme(
        panel.background = element_rect(fill = "transparent"), # bg of the panel
    plot.background = element_rect(fill = "transparent", color = NA), # bg of the plot
axis.text  = element_blank(),
        panel.grid = element_blank(),
        axis.line  = element_blank(),
        axis.ticks = element_blank(),
        axis.title = element_blank()# get rid of legend panel bg
  ) 
  
#save plot as image
ggsave(filename = 'plot1.png', device = 'png', bg = 'transparent') 


# read images with package magick
 plot1 <- image_read('plot1.png')

image_mosaic(plot1)

CẬP NHẬT:

Cũng giống như mã trước đó, nhưng gần hơn với những gì bạn đang tìm kiếm…

Một số mã Unicode nhất định chỉ hoạt động với một số phông chữ nhất định, vì vậy bước đầu tiên là tìm phông chữ nào phù hợp với Unicode mà bạn đang tìm kiếm. Đây là một ví dụ về hỗ trợ phông chữ cho một loại mũi tên sang trái được sử dụng trong ví dụ bên dưới .

Tất nhiên, không có phông chữ nào trong danh sách là tiêu chuẩn, bởi vì cuộc sống không dễ dàng như vậy. Vì vậy, bước tiếp theo là cài đặt phông chữ. Tôi đã sử dụng phông chữ Symbola mà tôi đã tải xuống ở đây . Sao chép tệp phông chữ vào thư mục R của bạn hoặc vào thư mục dự án của bạn nếu bạn đang sử dụng các dự án.

Sau đó sử dụng thư viện showtext . Gói cho phép bạn sử dụng phông chữ hệ thống trong đồ họa (yêu cầu gói sysfonts). Nếu phông chữ là tiêu chuẩn trong hệ điều hành của bạn, tôi khuyên bạn nên xem gói systemfonts .

Trong ví dụ của tôi, tôi sử dụng các mũi tên \U1F800\U1F801, sau đó, như trong ví dụ trước đây của tôi, tôi chồng chéo họ ( PS: bạn có thể phải đánh lừa xung quanh với nudge_ynudge_xtrong geom_textđể chúng liên kết đúng) :

library(data.table)
library(magick)
library(ggplot2)
library(showtext)



#layer 1, upwards arrow
arrows1 = matrix(c("", "\U1F801", "\U1F801", ""),
                 nrow = 2,
                 ncol = 2)
grid_arrows1 = expand.grid(x = 1:ncol(arrows1), y = 1:nrow(arrows1))
grid_arrows1$val = arrows1[as.matrix(grid_arrows1[c('y', 'x')])] #layer 2 , leftwards arrow arrows2 = matrix(c("", "\U1F800", "\U1F800", ""), nrow = 2, ncol = 2) grid_arrows2 = expand.grid(x1 = 1:ncol(arrows2), y1 = 1:nrow(arrows2)) grid_arrows2$val = arrows2[as.matrix(grid_arrows2[c('y1', 'x1')])]

#layer 3 , upwards arrow
arrows3  = matrix(c("\U1F801", "", "", "\U1F801"),
                  nrow = 2,
                  ncol = 2)
grid_arrows3 = expand.grid(x2 = 1:ncol(arrows3), y2 = 1:nrow(arrows3))
grid_arrows3$val = arrows3[as.matrix(grid_arrows3[c('y2', 'x2')])] #layer 4 , leftwards arrow arrows4 = matrix(c("\U1F800", "", "", "\U1F800"), nrow = 2, ncol = 2) grid_arrows4 = expand.grid(x3 = 1:ncol(arrows4), y3 = 1:nrow(arrows4)) grid_arrows4$val = arrows4[as.matrix(grid_arrows4[c('y3', 'x3')])]

#use function font_add from lybrary showtext
 font_add("Symbola", regular = "Symbola_hint.ttf")
# Take a look at the function showtext_auto() as well

 ggplot(grid_arrows1,
        aes(x = x, y = y, label = val),
        family = 'Symbola',
        size = 18) +
   
   geom_tile(fill = 'NA', colour = 'black') +
   geom_text(
     grid_arrows1,
     mapping = aes(x = x, y = y, label = val),
     family = 'Symbola',
     size = 18
   ) +
   
   geom_text(
     grid_arrows2,
     mapping =  aes(x = x1, y = y1, label = val),
     family = 'Symbola',
     size = 18,
     nudge_x = -0.01
   ) +
   geom_text(
     grid_arrows1,
     mapping =  aes(x = x, y = y, label = val),
     family = 'Symbola',
     size = 18,
     angle = 180
   ) +
   geom_text(
     grid_arrows2,
     mapping =  aes(x = x1, y = y1, label = val),
     family = 'Symbola',
     size = 18,
     angle = 180,
     nudge_x = 0.01,
     nudge_y = 0.007
   ) +
   geom_text(
     grid_arrows3,
     mapping =  aes(x = x2, y = y2, label = val),
     family = 'Symbola',
     size = 17,
     nudge_y = 0.03
   ) +
   geom_text(
     grid_arrows4,
     mapping =  aes(x = x3, y = y3, label = val),
     family = 'Symbola',
     size = 17,
     nudge_x = -0.021,
     nudge_y = -0.01
   ) +
   
   scale_y_reverse() +
   theme_classic() +
   theme(
     panel.background = element_rect(fill = "transparent"),
     # bg of the panel
     plot.background = element_rect(fill = "transparent", color = NA),
     # bg of the plot
     axis.text  = element_blank(),
     panel.grid = element_blank(),
     axis.line  = element_blank(),
     axis.ticks = element_blank(),
     axis.title = element_blank()# get rid of legend panel bg
   ) 
 
 #save plot as image
 ggsave(filename = 'plot.png',
        device = 'png',
        bg = 'transparent')
 
 # read images with package magick
 image_read('plot.png')

Đây là kết quả tôi nhận được:

Tôi không thể nói rằng đây là đoạn mã đẹp nhất từng thấy, nó cũng dễ bị hack, nhưng nó có thể hữu ích! (Phải mất nhiều thời gian để hoàn thành việc này hơn tôi muốn thừa nhận!)