uni memo

機械学習を用いたA/Bテスト - チュートリアル -

機械学習を用いたA/Bテスト - チュートリアル -

こちらの記事 A/B Testing with Machine Learning – A Step-by-Step Tutorial /| R-bloggers の流れを追ってみた

問題設定+ソースコードを追ってみた。Rで行う

背景

従来の統計的推論では

treatment/control group
(処置群と統制群)のみ比較していたが

実際、ユーザの行動はもっと複雑であり

サイトでの行動は様々な(ユーザの特性とか)に起因していると考えられる

そのような要因を得るために、機械学習の適用を試みる

問題設定

udacityのデータを用いる

  • Experiment Name: “Free Trial” Screener

Udacityのサイトにおいて

1週間で5時間未満学習する人は挫折し、 1週間で5時間以上学習する人は継続する傾向があることがわかっている

start free trial
を行ったユーザに対して、学習できる時間を入力させるpop upを表示する。自身に学習時間をコミットさせるために書かせる

このフォームの表示有無をA/Bテストし、登録にいかに寄与するかを測る

また、free trial期間を最適化する(無料期間中に学習を終わらせない、かつ、学習を挫折させずに継続させる)ことを目的とする

データと中身の確認

kaggleからダウンロードする

各々データの中身は以下

  • Date: 月日と、曜日の日付データ
  • Pageviews: PV数
  • Clicks: クリック数
  • Enrollments: 登録会員数
  • Payments: 支払い数

また、要点を抜き出すと

  • データは各々37行
  • 日別データ。ユーザの振る舞いを見るには粒度が荒いが、今回はこれで
  • 日付データがcharactor型になっているので、何らかの形で変換してやる必要がある。後ほど、曜日を抽出する
  • 各データは形式が同じ、かつ日付も同じなので、A/Bテストが同時に行われた時のデータであろう

予測データの抽出には

  • 曜日情報のみ予測器に用いる
  • 登録会員数を回帰で予測する
  • 欠損値は除く

を行う

実装

データ読み込み

  • パッケージ
# Core packages
library(tidyverse)
library(tidyquant)

# Modeling packages
library(parsnip)
library(recipes)
library(rsample)
library(yardstick)
library(broom)

# Connector packages
library(rpart)
library(rpart.plot)
library(xgboost)
  • 処置群/統制群データの読み込み
# load data
control_df <- read_csv("data/raw/control_data.csv")
experiment_df <- read_csv("data/raw/experiment_data.csv")

データ確認&加工

データの質を確認する

dplyr
のパイプで作成する

欠損値は14個ある

# 欠損値の確認
# naの数を数えて、みやすく整える
control_df %>%
  map_df(~sum(is.na(.))) %>%
  gather(key = "feature", value = "missing_count") %>%
  arrange(desc(missing_count))

experiment_df %>%
  map_df(~sum(is.na(.))) %>%
  gather(key = "feature", value = "missing_count") %>%
  arrange(desc(missing_count))
# A tibble: 5 x 2
  feature     missing_count
  <chr>               <int>
1 Enrollments            14
2 Payments               14
3 Date                    0
4 Pageviews               0
5 Clicks                  0

また、同じ時期欠損していることがわかる

control_df %>%
  filter(is.na(Enrollments))

experiment_df %>%
  filter(is.na(Enrollments))

データ加工

今回のターゲットはEnrollmentsなのでPaymentsは削除する

# データの加工
set.seed(42)
formatted_df <- control_df %>%
  
  # Combine with Experiment data
  bind_rows(experiment_df, .id = "Experiment") %>%
  # control_dfのみ-1(=0)になる
  mutate(Experiment = as.numeric(Experiment) - 1) %>%
  
  # Add row id 上から順にIDを割り振る
  mutate(row_id = row_number()) %>%
  
  # Create a Day of Week feature
  # 曜日抽出、最初の3文字を抜き出す
  mutate(DOW = str_sub(Date, start = 1, end = 3) %>% 
          factor(levels = c("Sun", "Mon", "Tue", "Wed", 
                             "Thu", "Fri", "Sat"))
  ) %>%
  # いらない行削除
  select(-Date, -Payments) %>%
  
  # Remove missing data
  filter(!is.na(Enrollments)) %>%
  
  # Shuffle the data (note that set.seed is used to make reproducible)
  # データ混合。seedを設定して、再現性確保
  sample_frac(size = 1) %>%
  
  # データの列位置いじる
  select(row_id, Enrollments, Experiment, everything())

こんな感じになる

# A tibble: 74 x 8
   Experiment Date        Pageviews Clicks Enrollments Payments row_id DOW  
        <dbl> <chr>           <dbl>  <dbl>       <dbl>    <dbl>  <int> <fct>
 1          0 Sat, Oct 11      7723    687         134       70      1 Sat  
 2          0 Sun, Oct 12      9102    779         147       70      2 Sun  
 3          0 Mon, Oct 13     10511    909         167       95      3 Mon  

学習、予測データに分ける

# 学習、予測データに分ける8:2の割合
set.seed(42)
split_obj <- formatted_df %>%
  initial_split(prop = 0.8, strata = "Experiment")

train_df <- training(split_obj)
test_df  <- testing(split_obj)

機械学習モデルの適用

以下を試す

  1. Linear Regression – Linear, Explainable (Baseline)
  2. Decision Tree
    Pros: Non-Linear, Explainable
    Cons: Lower Performance
  3. XGBoost
    Pros: Non-Linear, High Performance
    Cons: Less Explainable

parsnip
パッケージを用いて、パイプラインっぽくモデル定義していく

その前に結果確認用関数を定義しておく

# メトリクスの計算
calc_metrics <- function(model, new_data) {
  model %>%
    predict(new_data = new_data) %>%
    bind_cols(new_data %>% select(Enrollments)) %>%
    metrics(truth = Enrollments, 
            estimate = .pred)
}

# 予測値の描画
plot_predictions <- function(model, new_data) {
  
  g <- predict(model, new_data) %>%
    bind_cols(new_data %>% select(Enrollments)) %>%
    mutate(observation = row_number() %>% as.character()) %>%
    gather(key = "key", value = "value", -observation, factor_key = TRUE) %>%
    
    # Visualize
    ggplot(aes(x = observation, y = value, color = key)) +
    geom_point() +
    expand_limits(y = 0) +
    theme_tq() +
    scale_color_tq()
  
  return(g)
}

Linear Regression – Linear, Explainable (Baseline)

# 1. 線形回帰
# 学習 statsパッケージのlmを使用する
model_01_lm <- linear_reg("regression") %>%
  set_engine("lm") %>%
  fit(Enrollments ~ ., data = train_df %>% select(-row_id))

# 予測
model_01_lm %>% 
  calc_metrics(test_df) %>%
  knitr::kable()

# 可視化
model_01_lm %>% 
  plot_predictions(test_df) +
  labs(title = "Enrollments: Prediction vs Actual",
       subtitle = "Model 01: Linear Regression (Baseline)")

Experiment(0 or 1)の係数が-9.6なので

Experiment群は一日経つと、Enrollmentsが-9.6減ると言える

回帰における重要な特徴量(説明変数)を確認する際には、p-valueを用いる

ここでp-valueは

係数が0である
という帰無仮説についての検定における値である

p-valueが0.05より小さいと有意差有り、すなわち予測式で重要と言える

結果を見るとClicks, Pageviewsが重要と出ている

linear_regression_model_terms_tbl <- model_01_lm$fit %>%
  # p-valueを取得
  tidy() %>%
  arrange(p.value) %>%
  mutate(term = as_factor(term) %>% fct_rev()) 

linear_regression_model_terms_tbl %>% knitr::kable()

linear_regression_model_terms_tbl %>%
  ggplot(aes(x = p.value, y = term)) +
  geom_point(color = "#2C3E50") +
  geom_vline(xintercept = 0.05, linetype = 2, color = "red") +
  theme_tq() +
  labs(title = "Feature Importance",
       subtitle = "Model 01: Linear Regression (Baseline)")

Decision Tree


# 決定木
model_02_decision_tree <- decision_tree(
  mode = "regression",
  cost_complexity = 0.001, 
  tree_depth = 5, 
  min_n = 4) %>%
  set_engine("rpart") %>%
  fit(Enrollments ~ ., data = train_df %>% select(-row_id))

model_02_decision_tree %>% 
  calc_metrics(test_df) %>%
  knitr::kable()

model_02_decision_tree %>% 
  plot_predictions(test_df) +
  labs(title = "Enrollments: Prediction vs Actual",
       subtitle = "Model 02: Decision Tree")
# 決定木の可視化
model_02_decision_tree$fit %>%
  rpart.plot(
    roundint = FALSE, 
    cex = 0.8, 
    fallen.leaves = TRUE,
    extra = 101, 
    main = "Model 02: Decision Tree")

決定木の可視化。図よりExperimentが0.5より小さい(=統制群である)

とEnrollmentsが大きい傾向がみて取れる

XGBoost

# XGboost
set.seed(42)
model_03_xgboost <- boost_tree(
  mode = "regression",
  mtry = 100, 
  trees = 1000, 
  min_n = 8, 
  tree_depth = 6, 
  learn_rate = 0.2, 
  loss_reduction = 0.01, 
  sample_size = 1) %>%
  set_engine("xgboost") %>%
  fit(Enrollments ~ ., data = train_df %>% select(-row_id))

model_03_xgboost %>% 
  calc_metrics(test_df) %>%
  knitr::kable()

model_03_xgboost %>% plot_predictions(test_df) +
  labs(title = "Enrollments: Prediction vs Actual",
       subtitle = "Model 02: Decision Tree")
       

上2つよりもよくなってる

importanceをみてみる

# importance
xgboost_feature_importance_tbl <- model_03_xgboost$fit %>%
  xgb.importance(model = .) %>%
  as_tibble() %>%
  mutate(Feature = as_factor(Feature) %>% fct_rev())

xgboost_feature_importance_tbl %>% knitr::kable()
|Feature    |      Gain|     Cover| Frequency|
|:----------|---------:|---------:|---------:|
|Pageviews  | 0.5885359| 0.5444576| 0.5329480|
|Clicks     | 0.3567912| 0.3486039| 0.3425819|
|Experiment | 0.0546729| 0.1069385| 0.1244701|

PV数+クリック数がほとんどのgainの93%を占めており

会員を集めるなら、PV数を増やすといい、という(当然の)結果が見える

xgboost_feature_importance_tbl %>%
  ggplot(aes(x = Gain, y = Feature)) +
  geom_point(color = "#2C3E50") +
  geom_label(aes(label = scales::percent(Gain)), 
             hjust = "inward", color = "#2C3E50") +
  expand_limits(x = 0) +
  theme_tq() +
  labs(title = "XGBoost Feature Importance") 

結論

MLを使うことにより効果があったか否かだけでなく

どの変数が影響を及ぼすか+その影響度合いもみれる

今回は、pop upを出すことで会員登録数を減らす効果があることがわかった

さらに、PV数と、クリック数が重要なこと、pop upをだすと1日で-9.6人、減るということがわかった

会員の増加に重きをおくか、学習時間をコミットさせることによる効果をさらに深掘りするかはまた、別な話になる

用途

いろんな属性データを追加して、寄与している(もしくはしていない)特徴量

をみることで、行った施策の影響を数字で考えたり、

新しく施策を考える材料として使うことができそう

参考

2024, Built with Gatsby. This site uses Google Analytics.