본문 바로가기

AIFFLE/PROJECT

[kaggle] 포켓몬 분류하기 프로젝트 (Tree 모델들의 Feature importance를 확인해보자)

728x90
학습 내용
  • Tree 모델들의 Feature importance를 확인해보자
  • Feature importance를 바탕으로 Feature를 선택하고 이에 따른 성능 변화를 확인할 수 있다

 

추가적으로 배울 수 있는 것

  • 다양한 전처리 기법들 ( 문자열 전처리를 위한 정규식, 원핫인코딩 형태로 변환)
  • 시각화 방법

 

들어가며

  • 오늘은 전설의 포켓몬 분류하기 프로젝트를 진행하였습니다.
  • 그러다가 문득 상관관계가 높은 피처를 발견하게 되었고, 이 피처의 존재유무가 성능에 어떤 영항을 끼치는 지에 대한 의문점이 생겨서 이에 대해 학습하게 되었습니다.
  • 모델은 결정트리 한가지 모델만 사용하였고, 피처 엔지니어링에 따른 변화를 확인하는 데에 중점을 두었습니다.
  • 나아가 트리모델에서의 피처 중요도를 확인하고, 시각화하는 과정을 중심으로 봐주시면 좋을 것 같습니다!

 

분석한 데이터 출처

 

Pokemon with stats

721 Pokemon with stats and types

www.kaggle.com

 

간단하게 데이터에 대해 설명해 보자면 

- # : 포켓몬 Id number. 성별이 다르지만 같은 포켓몬인 경우 등은 같은 #값을 가진다. int
- Name : 포켓몬 이름. 포켓몬 각각의 이름으로 저장되고, 800개의 포켓몬의 이름 데이터는 모두 다르다. (unique) `str`
- Type 1 : 첫 번째 속성. 속성을 하나만 가지는 경우 Type 1에 입력된다. `str`
- Type 2 : 두 번째 속성. 속성을 하나만 가지는 포켓몬의 경우 Type 2는 NaN(결측값)을 가진다. `str`
- Total : 전체 6가지 스탯의 총합. `int`
- HP : 포켓몬의 체력. `int`
- Attack : 물리 공격력. (scratch, punch 등) `int`
- Defense : 물리 공격에 대한 방어력. `int`
- Sp. Atk : 특수 공격력. (fire blast, bubble beam 등) `int`
- Sp. Def : 특수 공격에 대한 방어력. `int`
- Speed : 포켓몬 매치에 대해 어떤 포켓몬이 먼저 공격할지를 결정. (더 높은 포켓몬이 먼저 공격한다) `int`
- Generation : 포켓몬의 세대. 현재 데이터에는 6세대까지 있다. `int`
- Legendary : 전설의 포켓몬 여부. !! Target feature !! `bool`

위와 같은 칼럼으로 구성되어 있는 데이터이고 전설의 포켓몬인지 구분하는 분류 문제입니다.

 

 

전설의 포켓몬과 일반 포켓몬의 속성은 차이가 있을까?

  • 속성에 대한 변수가 Type 1, Type 2로 되어 있고 1개의 속성을 가진 포켓몬도 있고, 2개의 속성을 가진 포켓몬도 있었습니다.
  • 전설의 포켓몬과 일반 포켓몬의 속성을 비교하기 위해 type1과 type2를 하나로 합쳐주고 시각화를 진행하였습니다.
ordinary_type = pd.concat([ordinary['Type 1'], ordinary['Type 2'].fillna('nan')])
legendary_type = pd.concat([legendary['Type 1'], legendary['Type 2'].fillna('nan')])

plt.figure(figsize=(16, 7))  # 화면 해상도에 따라 그래프 크기를 조정해 주세요.

plt.subplot(211)
sns.countplot(ordinary_type, order=types).set_xlabel('')
plt.title("[Ordinary Pokemons]")

plt.subplot(212)
sns.countplot(legendary_type, order=types).set_xlabel('')
plt.title("[Legendary Pokemons]")

plt.show()

 

  • 일반 포켓몬과 전설의 포켓몬의 분류는 명확한 차이가 존재하는 것을 확인할 수 있었습니다.

 

"Total" 피처의 의미

  • total : 6개 스탯의 총합
  • stats = ["HP", "Attack", "Defense", "Sp. Atk", "Sp. Def", "Speed"]
  • total 변수는 stats 값들의 합이므로 다중공선성을 가지므로 stats 변수를 선택한다면 최소 하나의 변수는 제거하는 것이 적절하고 생각했습니다

 

포켓몬의 이름

  • 포켓몬의 이름은 총 네 가지 타입으로 나뉩니다.
    • 한 단어면 ex. Venusaur 두 단어이고, 앞 단어는 두 개의 대문자를 가지며 대문자를 기준으로 두 부분으로 나뉘는 경우 ex. VenusaurMega Venusaur
    • 이름은 두 단어이고, 맨 뒤에 X, Y로 성별을 표시하는 경우 ex. CharizardMega Charizard X
    • 알파벳이 아닌 문자를 포함하는 경우 ex. Zygarde50% Forme
    • 이름에 알파벳이 아닌 문자가 들어간 경우

전처리하기

  • 이 중 가장 먼저 '알파벳이 아닌 문자'를 포함하는 경우를 처리하도록 하겠습니다. 어떤 문자열이 알파벳으로만 이루어져 있는지를 확인하고 싶을 때는 isalpha() 함수를 사용하면 편리합니다.
  • pandas의 isalpha() 함수 우리는 알파벳이 아닌 문자를 포함하는 이름을 걸러내고 싶은데, 주의할 점은 이름에 띄어쓰기가 있는 경우에도 isalpha() = False로 처리된다는 점입니다. 따라서 알파벳 체크를 위해 띄어쓰기가 없는 컬럼을 따로 만들어준 후, 띄어쓰기를 빈칸으로 처리해서 확인하도록 하겠습니다.
pokemon["Name_nospace"] = pokemon["Name"].apply(lambda i: i.replace(" ", ""))
pokemon["name_isalpha"] = pokemon["Name_nospace"].apply(lambda i: i.isalpha())

# 데이터 확인 후 변환
pokemon = pokemon.replace(to_replace="Nidoran♀", value="Nidoran X")
pokemon = pokemon.replace(to_replace="Nidoran♂", value="Nidoran Y")
pokemon = pokemon.replace(to_replace="Farfetch'd", value="Farfetchd")
pokemon = pokemon.replace(to_replace="Mr. Mime", value="Mr Mime")
pokemon = pokemon.replace(to_replace="Porygon2", value="Porygon Two")
pokemon = pokemon.replace(to_replace="Ho-oh", value="Ho Oh")
pokemon = pokemon.replace(to_replace="Mime Jr.", value="Mime Jr")
pokemon = pokemon.replace(to_replace="Porygon-Z", value="Porygon Z")
pokemon = pokemon.replace(to_replace="Zygarde50% Forme", value="Zygarde Forme")

# 바꿔준 'Name' 컬럼으로 'Name_nospace'를 만들고, 다시 isalpha()로 체크
pokemon["Name_nospace"] = pokemon["Name"].apply(lambda i: i.replace(" ", ""))
pokemon["name_isalpha"] = pokemon["Name_nospace"].apply(lambda i: i.isalpha())
pokemon[pokemon["name_isalpha"] == False]

 

이름을 띄어쓰기 & 대문자 기준으로 분리해 토큰화

- [A-Z] : A부터 Z까지의 대문자 중 한 가지로 시작하고,
- [a-z] : 그 뒤에 a부터 z까지의 소문자 중 한 가지가 붙는데,
- * : 그 소문자의 개수는 하나 이상인 패턴 (`*`는 정규표현식 중에서 "반복"을 나타내는 기호)

import re

def tokenize(name):
    name_split = name.split(" ")
    tokens = []
    for part_name in name_split:
        a = re.findall('[A-Z][a-z]*', part_name)
        tokens.extend(a)
    return np.array(tokens)
    
all_tokens = list(legendary["Name"].apply(tokenize).values)

token_set = []
for token in all_tokens:
    token_set.extend(token)

print(len(set(token_set)))
print(token_set)
65
['Articuno', 'Zapdos', 'Moltres', 'Mewtwo', 'Mewtwo', 'Mega', 'Mewtwo', 'X', 'Mewtwo', 'Mega', 'Mewtwo', 'Y', 'Raikou', 'Entei', 'Suicune', 'Lugia', 'Ho', 'Regirock', 'Regice', 'Registeel', 'Latias', 'Latias', 'Mega', 'Latias', 'Latios', 'Latios', 'Mega', 'Latios', 'Kyogre', 'Kyogre', 'Primal', 'Kyogre', 'Groudon', 'Groudon', 'Primal', 'Groudon', 'Rayquaza', 'Rayquaza', 'Mega', 'Rayquaza', 'Jirachi', 'Deoxys', 'Normal', 'Forme', 'Deoxys', 'Attack', 'Forme', 'Deoxys', 'Defense', 'Forme', 'Deoxys', 'Speed', 'Forme', 'Uxie', 'Mesprit', 'Azelf', 'Dialga', 'Palkia', 'Heatran', 'Regigigas', 'Giratina', 'Altered', 'Forme', 'Giratina', 'Origin', 'Forme', 'Darkrai', 'Shaymin', 'Land', 'Forme', 'Shaymin', 'Sky', 'Forme', 'Arceus', 'Victini', 'Cobalion', 'Terrakion', 'Virizion', 'Tornadus', 'Incarnate', 'Forme', 'Tornadus', 'Therian', 'Forme', 'Thundurus', 'Incarnate', 'Forme', 'Thundurus', 'Therian', 'Forme', 'Reshiram', 'Zekrom', 'Landorus', 'Incarnate', 'Forme', 'Landorus', 'Therian', 'Forme', 'Kyurem', 'Kyurem', 'Black', 'Kyurem', 'Kyurem', 'White', 'Kyurem', 'Xerneas', 'Yveltal', 'Zygarde', 'Forme', 'Diancie', 'Diancie', 'Mega', 'Diancie', 'Hoopa', 'Hoopa', 'Confined', 'Hoopa', 'Hoopa', 'Unbound', 'Volcanion']

 

여기서 많이 사용된 토큰을 추출

import collections

most_common = Counter(token_set).most_common(10)
most_common
[('Forme', 15),
 ('Mega', 6),
 ('Mewtwo', 5),
 ('Kyurem', 5),
 ('Deoxys', 4),
 ('Hoopa', 4),
 ('Latias', 3),
 ('Latios', 3),
 ('Kyogre', 3),
 ('Groudon', 3)]
for token, _ in most_common:
    # pokemon[token] = ... 형식으로 사용하면 뒤에서 warning이 발생합니다
    pokemon[f"{token}"] = pokemon["Name"].str.contains(token)

pokemon.head(10)

Type 피처 boolean 형식으로 변환

types = ['Steel', 'Fairy', 'Ice', 'Bug', 'Fire', 'Normal', 'Ground', 'Ghost', 'Poison', 'Psychic', 'Electric', 'Rock', 'Flying', 'Dragon', 'Water', 'Grass', 'Dark', 'Fighting']

for t in types:
    pokemon[t] = (pokemon["Type 1"] == t) | (pokemon["Type 2"] == t)
    
pokemon[[["Type 1", "Type 2"] + types][0]].head()

원하는 대로 잘 된 것을 확인할 수 있다.

 

모델 학습 (Baseline)

features = ['Total', 'HP', 'Attack', 'Defense','Sp. Atk', 'Sp. Def', 'Speed', 'Generation', 
            'name_count','long_name', 'Forme', 'Mega', 'Mewtwo','Deoxys', 'Kyurem', 'Latias', 'Latios',
            'Kyogre', 'Groudon', 'Hoopa','Poison', 'Ground', 'Flying', 'Normal', 'Water', 'Fire',
            'Electric','Rock', 'Dark', 'Fairy', 'Steel', 'Ghost', 'Psychic', 'Ice', 'Bug', 'Grass', 'Dragon', 'Fighting']

X = pokemon[features]
y = pokemon['Legendary']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=15)

model = DecisionTreeClassifier(random_state=25)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_pred)

 

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred))

성능 개선을 위한 노력

1. Total 피처 없애기

  • 위에서 말했던 것처럼 다중공선성을 가지는 total 피처를 없애보았습니다.
X_train, X_test, y_train, y_test = train_test_split(X.drop('Total', axis=1), y, test_size=0.2, random_state=15)

model = DecisionTreeClassifier(random_state=25)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_pred)

  • 완전히 같은 모델로 성능을 확인해봤으나 성능이 확연하게 떨어지는 것을 볼 수 있었습니다.
  • 이로써 total 피처는 전설의 포켓몬을 분류하는 데 있어서 주요한 피처임을 확인할 수 있었습니다.
  • 이에 따라 상관관계를 확인하였습니다.

  • 히트맵을 통해 total 피처와 가장 높은 상관관계를 가진 Sp.Atk를 제거해봐야 겠다 생각했습니다.

2. Sp. Atk 피처 없애기

  • 코드는 위와 유사하게 진행하였습니다.
  • 결과 :

  • 피처가 모두 있었을 때 (38개) 보다 더 좋은 성능을 가지게 되었습니다!

 

실험을 하다보니 Feature Importance를 확인해보고 싶어졌습니다.

  • 이와 관련해서 코드를 찾다보니 permutation importance에 대해 알게 되었고, 피처의 변화에 따른 성능변화를 통해 피처의 중요도를 파악하는 방식이였습니다.
from sklearn.inspection import permutation_importance

def plot_permutation_importance(clf, X, y, ax):
    result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2)
    perm_sorted_idx = result.importances_mean.argsort()

    ax.boxplot(
        result.importances[perm_sorted_idx].T,
        vert=False,
        labels=X.columns[perm_sorted_idx],
    )
    ax.axvline(x=0, color="k", linestyle="--")
    return ax

 코드는 이와 같습니다.

clf = DecisionTreeClassifier(random_state=25)
clf.fit(X_train, y_train)
print(f"Baseline accuracy on test data: {clf.score(X_test, y_test):.2}")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

mdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns)
tree_importance_sorted_idx = np.argsort(clf.feature_importances_)
tree_indices = np.arange(0, len(clf.feature_importances_)) + 0.5

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
mdi_importances.sort_values().plot.barh(ax=ax1)
ax1.set_xlabel("Gini importance")
plot_permutation_importance(clf, X_train, y_train, ax2)
ax2.set_xlabel("Decrease in accuracy score")
fig.suptitle(
    "Impurity-based vs. permutation importances on multicollinear features (train set)"
)
_ = fig.tight_layout()

대부분의 값들의 중요도가 0이였습니다. 중요하지 않은 변수를 제거해보면 좋을 것 같습니다.

from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
corr = spearmanr(X).correlation

# Ensure the correlation matrix is symmetric
corr = (corr + corr.T) / 2
np.fill_diagonal(corr, 1)

# 상관관계 matrix를 거리 matrix로 변환합니다
# Ward's linkage를 이용한 hierarchical clustering을 진행하겠습니다.
distance_matrix = 1 - np.abs(corr)
dist_linkage = hierarchy.ward(squareform(distance_matrix))

dendro = hierarchy.dendrogram(
    dist_linkage, labels=X.columns.to_list(), ax=ax1, leaf_rotation=90
)
dendro_idx = np.arange(0, len(dendro["ivl"]))

ax2.imshow(corr[dendro["leaves"], :][:, dendro["leaves"]])
ax2.set_xticks(dendro_idx)
ax2.set_yticks(dendro_idx)
ax2.set_xticklabels(dendro["ivl"], rotation="vertical")
ax2.set_yticklabels(dendro["ivl"])
_ = fig.tight_layout()

 

위 덴드로그램을 통해 뽑아낸 데이터의 관계를 바탕으로 변수를 선정합니다.

from collections import defaultdict

cluster_ids = hierarchy.fcluster(dist_linkage, 1, criterion="distance")
cluster_id_to_feature_ids = defaultdict(list)
for idx, cluster_id in enumerate(cluster_ids):
    cluster_id_to_feature_ids[cluster_id].append(idx)
selected_features = [v[0] for v in cluster_id_to_feature_ids.values()]
selected_features_names = X.columns[selected_features]

X_train_sel = X_train[selected_features_names]
X_test_sel = X_test[selected_features_names]

clf_sel = DecisionTreeClassifier(random_state=25)
clf_sel.fit(X_train_sel, y_train)
fig, ax = plt.subplots(figsize=(7, 6))
plot_permutation_importance(clf_sel, X_test_sel, y_test, ax)
ax.set_title("Permutation Importances on selected subset of features\n(test set)")
ax.set_xlabel("Decrease in accuracy score")
ax.figure.tight_layout()
plt.show()

성능 :

피처가 줄어들었음에도 더 좋은 성능을 가지게 되었습니다.

Plot_tree

import matplotlib.pyplot as plt
from sklearn import tree

fig = plt.figure(figsize=(15,7))
_ = tree.plot_tree(clf_sel,feature_names = clf_sel.feature_names_in_,
                   filled=True)

 

결론

  • 우연히 든 궁금증으로 여러가지에 대해  많이 고민해 볼 수 있었던 시간이었습니다.
  • 이로써 tree관련 permutation importance, feature importance, tree plot까지 공부할 수 있었던 시간이었습니다.
  • 이 페이지를 통해서 tree 모델의 성능 개선점들에 대해서 같이 생각해볼 수 있었다면 좋을 것 같습니다! 

 

 

참고자료

- https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html

- https://scikit-learn.org/stable/auto_examples/inspection/plot_permutation_importance_multicollinear.html#sphx-glr-auto-examples-inspection-plot-permutation-importance-multicollinear-py

728x90