728x90
학습 내용
- 케라스 API 구조를 이해한다
1. Sequential API
특징:
- 순차적인 레이어 쌓기 방식으로 간단하고 직관적.
- 단순한 네트워크 구조에 적합.
예제:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D
model = Sequential([
Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(28,28,1)),
Flatten(),
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
2. Functional API
특징:
- 복잡한 모델 구축에 더 유용 (예: 멀티 인풋/아웃풋 모델, 나무 구조 등을 구현 가능).
- 레이어를 그래프 형태로 연결하여 유연함.
예제:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D
inputs = Input(shape=(28,28,1))
x = Conv2D(32, kernel_size=(3,3), activation='relu')(inputs)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
outputs = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)
3. Subclassing API
특징:
- 완전한 유연성을 제공하여 모델을 직접 클래스로 정의 가능.
- 매우 복잡한 사용자 정의 동작과 반복 루프, 조건문 등을 사용할 수 있음.
예제:
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv = Conv2D(32, kernel_size=(3,3), activation='relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.conv(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
model = MyModel()
정리
1. Sequential과 functional의 차이
- Sequential: 간단하게 레이어를 차례로 쌓아 올리는 방식. 복잡한 모델(멀티 인풋/아웃풋, 나무 구조 등)에는 적합하지 않음.
- Functional: 레이어를 그래프로 정의할 수 있어 복잡한 구조를 손쉽게 구현할 수 있음. 더 유연하고 확장 가능한 모델을 구축 가능.
2. Sequential로는 구현이 불가능한 상황
- 멀티 인풋/아웃풋 모델 :
- Sequential API로는 두 개 이상의 입력이나 출력을 다루기 어려움.
- 이런 경우 Functional API가 필요.
3. Functional로 구현하는 것 보다 Subclassing으로 구현하는 것이 유리한 사례
- 사용자 정의 훈련 루프나 매우 복잡한 동작이 필요한 모델 (예: GANs, 강화 학습 에이전트):
- 유향 비순환 그래프로 표현할 수 없는 모델
- 동적 그래프를 활용한 강화 학습 모델 강화 학습 모델, 특히 딥 Q-네트워크(DQN)이나 재귀 강화 학습 모델은 행동과 보상 기준으로 모델이 동적으로 업데이트되며, 명확한 고정 그래프 구조가 없습니다.
- Generative Adversarial Networks (GANs)
- 복잡한 GANs 구조에서는 Generator와 Discriminator 간의 상호작용이 순환적인 형태로 나타날 수 있음
subclassing으로 구현한 GAN 예시
class GAN(keras.Model):
def __init__(self, generator, discriminator):
super(GAN, self).__init__()
self.generator = generator
self.discriminator = discriminator
def compile(self, generator_optimizer, discriminator_optimizer, loss_fn):
super(GAN, self).compile()
self.generator_optimizer = generator_optimizer
self.discriminator_optimizer = discriminator_optimizer
self.loss_fn = loss_fn
def train_step(self, data):
real_images, _ = data
# 생성기 그라디언트 계산
with tf.GradientTape() as gen_tape:
generated_images = self.generator(tf.random.normal(shape=(batch_size, noise_dim)), training=True)
fake_predictions = self.discriminator(generated_images, training=True)
gen_loss = self.loss_fn(tf.ones_like(fake_predictions), fake_predictions)
gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
# 판별기 그라디언트 계산
with tf.GradientTape() as disc_tape:
real_predictions = self.discriminator(real_images, training=True)
fake_predictions = self.discriminator(generated_images, training=True)
disc_loss = (self.loss_fn(tf.ones_like(real_predictions), real_predictions) + self.loss_fn(tf.zeros_like(fake_predictions), fake_predictions)) / 2
gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
return {"gen_loss": gen_loss, "disc_loss": disc_loss}
# 모델, 옵티마이저, 손실 함수 정의
generator = create_generator_model()
discriminator = create_discriminator_model()
gan = GAN(generator, discriminator)
gan.compile(generator_optimizer=keras.optimizers.Adam(),
discriminator_optimizer=keras.optimizers.Adam(),
loss_fn=keras.losses.BinaryCrossentropy(from_logits=True))
참고자료
케라스 창시자에게 배우는 딥러닝
728x90
'AIFFLE > STUDY' 카테고리의 다른 글
[NLP] Attention 쉽게 이해하기 (Query, Key, Value, Transformer에서의 attention 3종류) (0) | 2024.06.24 |
---|---|
[DL] 사용자 정의 훈련 스탭 (fit 메서드 커스터마이즈 하기) (0) | 2024.06.02 |
평가지표 - accuracy, precision, recall, F score, PR curve, AUC-ROC (0) | 2024.05.23 |
[DL] 일반화 성능 향상시키기 (0) | 2024.05.21 |
텐서의 이해 (0) | 2024.05.21 |