Phân tích dữ liệu lớn - Cây quyết định

Cây quyết định là một thuật toán được sử dụng cho các vấn đề học tập có giám sát như phân loại hoặc hồi quy. Cây quyết định hoặc cây phân loại là một cây trong đó mỗi nút bên trong (không phải nút) được gắn nhãn với một tính năng đầu vào. Các cung đến từ một nút được gắn nhãn đối tượng được gắn nhãn với từng giá trị có thể có của đối tượng. Mỗi lá của cây được gắn nhãn với một lớp hoặc một phân bố xác suất trên các lớp.

Một cây có thể được "học" bằng cách tách tập nguồn thành các tập con dựa trên kiểm tra giá trị thuộc tính. Quá trình này được lặp lại trên mỗi tập con dẫn xuất theo cách đệ quy được gọi làrecursive partitioning. Quá trình đệ quy được hoàn thành khi tập hợp con tại một nút có tất cả cùng giá trị của biến mục tiêu hoặc khi việc tách không còn thêm giá trị vào các dự đoán. Quá trình quy nạp từ trên xuống của cây quyết định là một ví dụ của thuật toán tham lam và nó là chiến lược phổ biến nhất để học cây quyết định.

Cây quyết định được sử dụng trong khai thác dữ liệu có hai loại chính:

  • Classification tree - khi phản hồi là một biến danh nghĩa, chẳng hạn như email có phải là thư rác hay không.

  • Regression tree - khi kết quả dự đoán có thể được coi là một con số thực (ví dụ tiền lương của một công nhân).

Cây quyết định là một phương pháp đơn giản và như vậy có một số vấn đề. Một trong những vấn đề này là phương sai cao trong các mô hình kết quả mà cây quyết định tạo ra. Để giảm bớt vấn đề này, các phương pháp tổng hợp về cây quyết định đã được phát triển. Có hai nhóm phương pháp tổng hợp hiện đang được sử dụng rộng rãi -

  • Bagging decision trees- Những cây này được sử dụng để xây dựng nhiều cây quyết định bằng cách lấy mẫu lại nhiều lần dữ liệu huấn luyện với sự thay thế và bỏ phiếu các cây để có dự đoán đồng thuận. Thuật toán này đã được gọi là rừng ngẫu nhiên.

  • Boosting decision trees- Tăng cường kết hợp học yếu; trong trường hợp này, cây quyết định thành một người học giỏi duy nhất, theo kiểu lặp đi lặp lại. Nó phù hợp với một cây yếu với dữ liệu và lặp đi lặp lại tiếp tục phù hợp với những người học yếu để sửa lỗi của mô hình trước đó.

# Install the party package
# install.packages('party') 
library(party) 
library(ggplot2)  

head(diamonds) 
# We will predict the cut of diamonds using the features available in the 
diamonds dataset. 
ct = ctree(cut ~ ., data = diamonds) 

# plot(ct, main="Conditional Inference Tree") 
# Example output 
# Response:  cut  
# Inputs:  carat, color, clarity, depth, table, price, x, y, z  

# Number of observations:  53940  
#  
# 1) table <= 57; criterion = 1, statistic = 10131.878 
#   2) depth <= 63; criterion = 1, statistic = 8377.279 
#     3) table <= 56.4; criterion = 1, statistic = 226.423 
#       4) z <= 2.64; criterion = 1, statistic = 70.393 
#         5) clarity <= VS1; criterion = 0.989, statistic = 10.48 
#           6) color <= E; criterion = 0.997, statistic = 12.829 
#             7)*  weights = 82  
#           6) color > E  

#Table of prediction errors 
table(predict(ct), diamonds$cut) 
#            Fair  Good Very Good Premium Ideal 
# Fair       1388   171        17       0    14 
# Good        102  2912       499      26    27 
# Very Good    54   998      3334     249   355 
# Premium      44   711      5054   11915  1167 
# Ideal        22   114      3178    1601 19988 
# Estimated class probabilities 
probs = predict(ct, newdata = diamonds, type = "prob") 
probs = do.call(rbind, probs) 
head(probs)