Дообучение предобученной NER модели

Подскажите пожалуйста, всё верно я делаю?

Моя задача такая: Я хочу взять предобученную модель ner_rus_bert, но дообучить её на своей задаче NER, с помощью train_model(если это возможно для моей модели вообще). Моя задача NER это задача распознавания сущностей Имя, Фамиля Отчество, Адрес и Пол. Исходная модель была предобученна на распознавании сущностей: Человек, Локация, Организация.
Что я делаю: Я получаю config модели ner_rus_bert с помощью parse_config, меняю dataset_reader data_path, в папке на которую я сменил, создаю свой датасет(test.txt, train.txt, valid.txt) в формате “Слово\tМетка\n” где отедльные предложения разделены пустой строкой, а Метки - это уже МОИ метки(например “B-NAM”, “B-SUR”), не те, на чем предобучалась модель, ну кроме метки “O”. В конфиге еще меняю save_path, load_path у simple_vocab и torch_transformers_sequence_tagger на свои и по сути запускаю train_model(model_config, download=True) далее сохраняю модель через .save(), следовательно вопрос, оно так обучится на мою задачу? Обучиться ли на мои метки при этом стартуя с предобученных весов исходной модели?
Я обучил таким образом модель на датасете с ФИО, однако пока модель путается в Фамилии, Имя и Отчество, что выглядет, будто она обучалась плохо, а может и вообще не обучалась… Может еще проблема в моём train датасете, но на всякий случай решил уточнить.

Здравствуйте! По вашему описанию вы всё делаете корректно. Перепроверьте что в папке с сохранённой моделью лежит файл tag.dict, в нём должны содержаться новые метки которые были указаны в вашем датасете. Можете также проверить логи обучения, если loss падает, то модель обучается корректно, а плохие предсказания могут возникать из-за недостатка обучающих данных или плохо подобранных параметров обучения. Если loss не уменьшается, то ошибку стоит поискать в изменённом torch_transformers_sequence_tagger.

Возможно в вашем случае будет лучше обучить модель с нуля не используя веса скачанные через конфиг ner_rus_bert.

1 Like

Спасибо большое, правда я не менял torch_transformers_sequence_tagger, я оставил всё практически кроме data_path. С данными действительно проблема, нужно пробовать обучить на хороших данных.