Source code for zoo.tfpark.text.estimator.bert_classifier

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from zoo.tfpark.text.estimator import *


def make_bert_classifier_model_fn(optimizer):
    """
    Create the model function used by BERTClassifier.

    :param optimizer: The optimizer that is wrapped in a ZooOptimizer to minimize
           the loss in 'train' mode. It may be None when the returned model function
           is only used in 'eval' or 'infer' mode.
    :return: A model function with the signature (features, labels, mode, params).
    """
    def _bert_classifier_model_fn(features, labels, mode, params):
        """
        Model function for BERTClassifier.

        :param features: Dict of feature tensors. Must include the key "input_ids".
        :param labels: Label tensor for training.
        :param mode: 'train', 'eval' or 'infer'.
        :param params: Must include the key "num_classes".
        :return: tf.estimator.EstimatorSpec.
        :raises ValueError: If mode is 'train' but no optimizer was provided.
        """
        import tensorflow as tf
        from zoo.tfpark import ZooOptimizer

        # Fail fast with a clear message, before any graph is built, rather than
        # handing None to ZooOptimizer at the very end of graph construction.
        if mode == tf.estimator.ModeKeys.TRAIN and optimizer is None:
            raise ValueError(
                "optimizer must be provided (an instance of tf.train.Optimizer) "
                "to train BERTClassifier, but got None")

        output_layer = bert_model(features, labels, mode, params).get_pooled_output()
        hidden_size = output_layer.shape[-1].value

        # Parameters of the classification layer put on top of the pooled output.
        output_weights = tf.get_variable(
            "output_weights", [params["num_classes"], hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.get_variable(
            "output_bias", [params["num_classes"]], initializer=tf.zeros_initializer())

        # Everything below, including the returned specs, is built inside the
        # "loss" variable scope.
        with tf.variable_scope("loss"):
            # Dropout is only applied while training.
            if mode == tf.estimator.ModeKeys.TRAIN:
                output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
            # output_weights is stored as [num_classes, hidden_size], hence transpose_b.
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probabilities = tf.nn.softmax(logits, axis=-1)

            # Prediction needs neither labels nor a loss.
            if mode == tf.estimator.ModeKeys.PREDICT:
                return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities)

            # Cross entropy between the one-hot encoded labels and the predicted
            # distribution, averaged over all examples.
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(labels, depth=params["num_classes"], dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
            loss = tf.reduce_mean(per_example_loss)

            if mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities,
                                                  loss=loss)

            # Any remaining mode is treated as training.
            train_op = ZooOptimizer(optimizer).minimize(loss)
            return tf.estimator.EstimatorSpec(mode=mode, train_op=train_op, loss=loss)

    return _bert_classifier_model_fn


class BERTClassifier(BERTBaseEstimator):
    """
    A pre-built TFEstimator that performs classification using the hidden state of
    the first token of BERT.

    :param num_classes: Positive int. The number of classes to be classified.
    :param bert_config_file: The path to the json file for BERT configurations.
    :param init_checkpoint: The path to the initial checkpoint of the pre-trained BERT model
           if any. Default is None.
    :param use_one_hot_embeddings: Boolean. Whether to use one-hot for word embeddings.
           Default is False.
    :param optimizer: The optimizer used to train the estimator. It should be an instance of
           tf.train.Optimizer. Default is None if no training is involved.
    :param model_dir: The output directory for model checkpoints to be written if any.
           Default is None.
    """
    def __init__(self, num_classes, bert_config_file, init_checkpoint=None,
                 use_one_hot_embeddings=False, optimizer=None, model_dir=None):
        # The optimizer is bound into the model function; every other argument is
        # forwarded unchanged to the base estimator.
        classifier_model_fn = make_bert_classifier_model_fn(optimizer)
        super(BERTClassifier, self).__init__(
            model_fn=classifier_model_fn,
            bert_config_file=bert_config_file,
            init_checkpoint=init_checkpoint,
            use_one_hot_embeddings=use_one_hot_embeddings,
            model_dir=model_dir,
            num_classes=num_classes)