본문 바로가기

AIFFLE/STUDY

[DL] keras API (Sequential, Functional, Subclassing) 이해하기

728x90

 

 

학습 내용
  • 케라스 API 구조를 이해한다

 

 

1. Sequential API

특징:

  • 순차적인 레이어 쌓기 방식으로 간단하고 직관적.
  • 단순한 네트워크 구조에 적합.

예제:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D

model = Sequential([
    Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(28,28,1)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

 

2. Functional API

특징:

  • 복잡한 모델 구축에 더 유용 (예: 멀티 인풋/아웃풋 모델, 나무 구조 등을 구현 가능).
  • 레이어를 그래프 형태로 연결하여 유연함.

예제:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D

inputs = Input(shape=(28,28,1))
x = Conv2D(32, kernel_size=(3,3), activation='relu')(inputs)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
outputs = Dense(10, activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

 

3. Subclassing API

특징:

  • 완전한 유연성을 제공하여 모델을 직접 클래스로 정의 가능.
  • 매우 복잡한 사용자 정의 동작과 반복 루프, 조건문 등을 사용할 수 있음.

예제:

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = Conv2D(32, kernel_size=(3,3), activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()

 

 

정리

1. Sequential과 functional의 차이

  • Sequential: 간단하게 레이어를 차례로 쌓아 올리는 방식. 복잡한 모델(멀티 인풋/아웃풋, 나무 구조 등)에는 적합하지 않음.
  • Functional: 레이어를 그래프로 정의할 수 있어 복잡한 구조를 손쉽게 구현할 수 있음. 더 유연하고 확장 가능한 모델을 구축 가능.

 

2. Sequential로는 구현이 불가능한 상황

  • 멀티 인풋/아웃풋 모델 :
    • Sequential API로는 두 개 이상의 입력이나 출력을 다루기 어려움.
    • 이런 경우 Functional API가 필요.

 

3. Functional로 구현하는 것 보다 Subclassing으로 구현하는 것이 유리한 사례

  • 사용자 정의 훈련 루프나 매우 복잡한 동작이 필요한 모델 (예: GANs, 강화 학습 에이전트):
    • 유향 비순환 그래프로 표현할 수 없는 모델
    • 동적 그래프를 활용한 강화 학습 모델 강화 학습 모델, 특히 딥 Q-네트워크(DQN)이나 재귀 강화 학습 모델은 행동과 보상 기준으로 모델이 동적으로 업데이트되며, 명확한 고정 그래프 구조가 없습니다.
    • Generative Adversarial Networks (GANs)
      • 복잡한 GANs 구조에서는 Generator와 Discriminator 간의 상호작용이 순환적인 형태로 나타날 수 있음

 

subclassing으로 구현한 GAN 예시

class GAN(keras.Model):

    def __init__(self, generator, discriminator):
        super(GAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def compile(self, generator_optimizer, discriminator_optimizer, loss_fn):
        super(GAN, self).compile()
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        real_images, _ = data

        # 생성기 그라디언트 계산
        with tf.GradientTape() as gen_tape:
            generated_images = self.generator(tf.random.normal(shape=(batch_size, noise_dim)), training=True)
            fake_predictions = self.discriminator(generated_images, training=True)
            gen_loss = self.loss_fn(tf.ones_like(fake_predictions), fake_predictions)
            
        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))

        # 판별기 그라디언트 계산
        with tf.GradientTape() as disc_tape:
            real_predictions = self.discriminator(real_images, training=True)
            fake_predictions = self.discriminator(generated_images, training=True)
            disc_loss = (self.loss_fn(tf.ones_like(real_predictions), real_predictions) + self.loss_fn(tf.zeros_like(fake_predictions), fake_predictions)) / 2



        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return {"gen_loss": gen_loss, "disc_loss": disc_loss}

# 모델, 옵티마이저, 손실 함수 정의
generator = create_generator_model()
discriminator = create_discriminator_model()
gan = GAN(generator, discriminator)

gan.compile(generator_optimizer=keras.optimizers.Adam(),
            discriminator_optimizer=keras.optimizers.Adam(),
            loss_fn=keras.losses.BinaryCrossentropy(from_logits=True))

 

참고자료

케라스 창시자에게 배우는 딥러닝

  •  
728x90