학습
이제 학습을 수행하는 부분을 작성하겠습니다.
이전에 작성한 config yaml을 불러옵니다. config는 OmegaConf를 이용하여 불러옵니다.
1
2
3
|
from omegaconf import OmegaConf
config = OmegaConf.load("config/train_config.yaml")
|
dataset과 model을 생성합니다.
1
2
3
4
5
6
7
8
9
10
11
12
|
dataset = {}
dataset["train"] = CorpusDataset(
config.train_data_path, preprocessor.get_input_features
)
dataset["val"] = CorpusDataset(
config.val_data_path, preprocessor.get_input_features
)
dataset["test"] = CorpusDataset(
config.test_data_path, preprocessor.get_input_features
)
bert_finetuner = SpacingBertModel(config, dataset)
|
logging과 학습에 사용할 callback을 작성합니다.
- checkpoint_callback : 매 epoch마다 validation loss를 계산해 loss가 줄어드는 경우 checkpoint를 저장합니다.
- early_stop_callback : validation loss가 더 이상 줄어들지 않으면 학습을 종료합니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
logger = TensorBoardLogger(
save_dir=os.path.join(config.log_path, config.task), version=1, name=config.task
)
checkpoint_callback = ModelCheckpoint(
filepath="checkpoints/"+ config.task + "/{epoch}_{val_loss:35f}",
verbose=True,
monitor="val_loss",
mode="min",
save_top_k=3,
prefix="",
)
early_stop_callback = EarlyStopping(
monitor="val_loss",
min_delta=0.0001,
patience=3,
verbose=False,
mode="min",
)
|
마지막으로 학습을 수행하고 테스트를 진행합니다.
PyTorch Lightning은 distributed training을 지원하기때문에 DistributedDataParallel을 이용해 학습을 하겠습니다.
1
2
3
4
5
6
7
8
9
10
|
trainer = pl.Trainer(
gpus=config.gpus,
distributed_backend=config.distributed_backend,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stop_callback,
logger=logger,
)
trainer.fit(bert_finetuner)
trainer.test()
|
결과
데이터셋
모두의 말뭉치 뉴스 데이터 전체를 학습하기에 시간이 오래걸려서 일부 데이터로 학습을 진행하였습니다.
- train : 100,000건
- val : 10,000건
- test : 10,000건
그래프



성능
-
F1
1
2
3
4
5
|
report: precision recall f1-score support
B 0.97 0.96 0.97 120963
I 0.91 0.91 0.91 112427
micro avg 0.94 0.94 0.94 233390
macro avg 0.94 0.94 0.94 233390
|
-
accuracy : 0.5294
결과 예시
1
2
|
gt : 특별한 일들이 생겨나고 있다.
pred : 특별한 일들이 생겨나고 있다.
|
1
2
|
gt : 몸싸움과 반격에 능한 차두리의 오버래핑은 이란의 강한 압박을 뚫고 주도권을 우리에게 유리하게 잡는데 필수적이다.
pred : 몸싸움과 반격에 능한 차두리의 오버래핑은 이란의 강한 압박을 뚫고 주도권을 우리에게 유리하게 잡는데 필수적이다.
|
1
2
|
gt : 2003년에 비교해 2010년 한국 사회는 어떤 모습인지?
pred : 2003년에 비교해 2010년 한국사회는 어떤 모습인지?
|
1
2
|
gt : 강원도 태백의 해발 935m인 삼수령 마루에 적혀있는 글이다.
pred : 강원도 태백의 해발 935m인 삼수령마루에 적혀 있는 글이다.
|
Reference