【深度学习 走进tensorflow2.0】使用RNN进行文本分类

此文本分类教程在 IMDB 大型电影评论数据集上训练循环神经网络（RNN），以进行情感分析。

主要步骤:

一、文本数据集处理
二、文本模型训练
三、文本模型预测

1、模型训练:

# -*- coding: utf-8 -*-



from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
import tensorflow as tf
import os


# Load the IMDB reviews already tokenized with the 8k-subword vocabulary.
# with_info=True also returns dataset metadata; as_supervised=True yields
# (text, label) pairs instead of feature dicts.
dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True, as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

# Subword encoder used to map raw text to token ids.
encoder = info.features['text'].encoder
print('Vocabulary size: {}'.format(encoder.vocab_size))

BUFFER_SIZE = 10000
BATCH_SIZE = 64

# Padded shapes of one (text, label) element: a variable-length id sequence
# and a scalar label. Spelled out explicitly because `Dataset.output_shapes`
# is a TF1-only attribute and raises AttributeError on TF 2.x datasets.
PADDED_SHAPES = ([None], [])

# Shuffle the training reviews.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)

# Reviews differ in length; padded_batch pads every sequence in a batch up to
# the longest one in that batch.
train_dataset = train_dataset.padded_batch(BATCH_SIZE, PADDED_SHAPES)

test_dataset = test_dataset.padded_batch(BATCH_SIZE, PADDED_SHAPES)



# Sentiment classifier: subword embeddings -> two stacked bidirectional
# LSTMs -> ReLU hidden layer with dropout -> a single sigmoid unit scoring
# the probability of label 1.
# (A lighter variant replaces both LSTM layers with a single
# Bidirectional(LSTM(64)) and drops the Dropout layer.)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(encoder.vocab_size, 64))
# return_sequences=True so the second LSTM receives every time step.
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)))
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

# Binary cross-entropy matches the single sigmoid output; small Adam step.
adam = tf.keras.optimizers.Adam(1e-4)
model.compile(optimizer=adam,
              loss='binary_crossentropy',
              metrics=['accuracy'])


# Checkpoint file name template, filled with the epoch number and that
# epoch's validation accuracy.
model_name = 'model_ex-{epoch:03d}_acc-{val_accuracy:03f}.h5'

trained_model_dir = './model/'
model_path = os.path.join(trained_model_dir, model_name)

# Make sure the checkpoint directory exists before the first save.
if not os.path.isdir(trained_model_dir):
    os.makedirs(trained_model_dir)

# Save the full model (save_weights_only=False) each time val_accuracy
# reaches a new maximum. Because the file name contains the epoch, every
# improvement is written to a new file.
# `save_freq='epoch'` is the non-deprecated equivalent of `period=1`.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    monitor='val_accuracy',
    verbose=1,
    save_weights_only=False,
    save_best_only=True,
    mode='max',
    save_freq='epoch')

# NOTE(review): the test split doubles as validation data, so checkpoint
# selection is driven by test-set accuracy (measured on 30 batches per epoch).
history = model.fit(train_dataset, epochs=100,
                    validation_data=test_dataset,
                    validation_steps=30, verbose=1, callbacks=[checkpoint])

# Evaluates the weights left after the last epoch, which are not necessarily
# the best checkpoint saved above.
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss: {}'.format(test_loss))
print('Test Accuracy: {}'.format(test_acc))

2、模型预测:


# -*- coding: utf-8 -*-


from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
import tensorflow as tf


# Only the subword encoder is needed here, to tokenize new text exactly as
# the training data was tokenized.
dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True, as_supervised=True)
encoder = info.features['text'].encoder

# Must stay identical to the training architecture: the saved weights are
# restored layer by layer into this model.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(encoder.vocab_size, 64),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])

# Checkpoint written by the training script. Replace the placeholder with an
# actual file name found in ./model/.
WEIGHTS_PATH = './model/xxxxxxxxxxxxxxx.h5'

# The model declares no input shape, so its variables do not exist until it
# is built; loading HDF5 weights into an unbuilt Sequential model raises
# ValueError. Build it for batches of variable-length token-id sequences.
model.build(input_shape=(None, None))
model.load_weights(WEIGHTS_PATH)


def pad_to_size(vec, size):
  """Return a copy of `vec` right-padded with zeros up to length `size`.

  The caller's sequence is never modified. If `vec` already has `size` or
  more elements, its contents are returned unchanged (as a new list).

  Args:
    vec: sequence of token ids (list, tuple, ...).
    size: target length.

  Returns:
    A new list of length max(len(vec), size).
  """
  # A negative repeat count yields an empty list, so no explicit bound check
  # is needed for inputs longer than `size`.
  return list(vec) + [0] * (size - len(vec))



def sample_predict(sample_pred_text, pad, pad_size=64):
  """Score a single raw text with the loaded model.

  The Embedding layer is built without a mask, so padding zeros are fed
  through the LSTMs and can shift the score. Whether to pad therefore
  changes the result.

  Args:
    sample_pred_text: text to classify.
    pad: if True, right-pad the token ids with zeros up to `pad_size`.
    pad_size: target length used when `pad` is True (default 64).

  Returns:
    The output of `model.predict` for a batch containing just this text.
  """
  token_ids = encoder.encode(sample_pred_text)

  if pad:
    token_ids = pad_to_size(token_ids, pad_size)
  # Token ids index the embedding table: keep them integral rather than
  # round-tripping through float32.
  batch = tf.expand_dims(tf.constant(token_ids, dtype=tf.int64), 0)
  return model.predict(batch)

if __name__ == '__main__':
    # Classify one hand-written review with padding enabled and print the
    # raw model output.
    review = ('The movie was cool. The animation and the graphics '
              'were out of this world. I would recommend this movie.')
    print(sample_predict(review, pad=True))
已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页