# Source code for zoo.tfpark.text.estimator.bert_base
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from zoo.tfpark.estimator import *
from bert import modeling
def bert_model(features, labels, mode, params):
    """
    Build a BertModel from the given input features and return it, so that the
    caller can take whichever of its outputs the specific task needs.

    :param features: mapping of input tensors. "input_ids" is required, while
           "input_mask" and "token_type_ids" are optional and passed to the model
           as None when they are absent.
    :param labels: not used by this function.
    :param mode: a tf.estimator.ModeKeys value. The model is built with
           is_training=True only for ModeKeys.TRAIN.
    :param params: dict which must contain the keys "bert_config_file",
           "use_one_hot_embeddings" and "init_checkpoint".
    :return: the constructed modeling.BertModel instance.
    """
    import tensorflow as tf

    input_ids = features["input_ids"]
    # The two optional inputs fall back to None when the caller did not provide them.
    input_mask = features["input_mask"] if "input_mask" in features else None
    token_type_ids = features["token_type_ids"] if "token_type_ids" in features else None

    config = modeling.BertConfig.from_json_file(params["bert_config_file"])
    training = mode == tf.estimator.ModeKeys.TRAIN
    bert = modeling.BertModel(
        config=config,
        is_training=training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=token_type_ids,
        use_one_hot_embeddings=params["use_one_hot_embeddings"])

    # Collected after the model is built, so the model's variables are included.
    trainable_vars = tf.trainable_variables()
    checkpoint = params["init_checkpoint"]
    if checkpoint:
        # Only the assignment map is needed here; the second return value is ignored.
        mapping, _ = modeling.get_assignment_map_from_checkpoint(trainable_vars, checkpoint)
        tf.train.init_from_checkpoint(checkpoint, mapping)
    return bert
def bert_input_fn(rdd, max_seq_length, batch_size,
                  features={"input_ids", "input_mask", "token_type_ids"},
                  extra_features=None, labels=None, label_size=None):
    """
    Takes an RDD to create the input function for BERT related TFEstimators.

    For training and evaluation, each element in rdd should be a tuple:
    (dict of features, a single label or dict of labels).
    Note that currently only integer or integer array labels are supported
    (every label is declared as tf.int32).
    For prediction, each element in rdd should be a dict of features.

    :param rdd: the RDD holding the data, handed to TFDataset.from_rdd.
    :param max_seq_length: length of each BERT feature; every feature named in
           `features` is declared as tf.int32 with shape [max_seq_length].
    :param batch_size: batch size. Used directly for training; for evaluation and
           prediction it is divided by the number of partitions of rdd and passed
           as batch_per_thread.
    :param features: iterable of BERT feature names, which must be a subset of
           {"input_ids", "input_mask", "token_type_ids"}. Default is all three.
    :param extra_features: optional dict for any other features in the dict of
           features, with feature name as key and tuple of (dtype, shape) as value.
    :param labels: None for a single label, or a list/tuple/set of label names.
           A collection with exactly one distinct name is treated as a single label;
           otherwise a dict with one entry per label name is used.
    :param label_size: None for a scalar label (shape []), otherwise the labels are
           declared with shape [label_size].
    :return: a function taking a tf.estimator.ModeKeys value and returning the
             TFDataset for that mode.
    :raises ValueError: if labels is neither None nor a list/tuple/set.
    :raises AssertionError: if features or extra_features are malformed.
    """
    import tensorflow as tf

    # Copy into a fresh set: this accepts any iterable of names (list, tuple, set, ...)
    # and guarantees that the shared default argument is never touched.
    feature_names = set(features)
    assert feature_names.issubset({"input_ids", "input_mask", "token_type_ids"}), \
        "features should be a subset of {input_ids, input_mask, token_type_ids}"
    features_dict = {name: (tf.int32, [max_seq_length]) for name in feature_names}

    if extra_features is not None:
        assert isinstance(extra_features, dict), "extra_features should be a dictionary"
        # Imported locally because this module has no explicit `import six` of its own.
        import six
        for name, spec in extra_features.items():
            assert isinstance(name, six.string_types), \
                "keys of extra_features should be feature names (strings)"
            assert isinstance(spec, tuple), \
                "values of extra_features should be tuples of (dtype, shape)"
            features_dict[name] = spec

    label_shape = [] if label_size is None else [label_size]
    if labels is None:
        res_labels = (tf.int32, label_shape)
    elif isinstance(labels, (list, tuple, set, frozenset)):
        label_names = set(labels)
        if len(label_names) == 1:
            # A single named label is handled exactly like an unnamed single label.
            res_labels = (tf.int32, label_shape)
        else:
            res_labels = {name: (tf.int32, label_shape) for name in label_names}
    else:
        raise ValueError("Wrong labels. "
                         "labels should be a set of label names if you have multiple labels")

    def input_fn(mode):
        """Create the TFDataset matching the given tf.estimator.ModeKeys value."""
        if mode == tf.estimator.ModeKeys.TRAIN:
            return TFDataset.from_rdd(rdd,
                                      features=features_dict,
                                      labels=res_labels,
                                      batch_size=batch_size)
        # NOTE(review): the integer division yields 0 when batch_size is smaller than
        # the number of partitions; confirm how TFDataset treats batch_per_thread=0.
        batch_per_thread = batch_size // rdd.getNumPartitions()
        if mode == tf.estimator.ModeKeys.EVAL:
            return TFDataset.from_rdd(rdd,
                                      features=features_dict,
                                      labels=res_labels,
                                      batch_per_thread=batch_per_thread)
        # Any other mode (prediction): the elements carry no labels.
        return TFDataset.from_rdd(rdd,
                                  features=features_dict,
                                  batch_per_thread=batch_per_thread)

    return input_fn
class BERTBaseEstimator(TFEstimator):
    """
    The base class for BERT related TFEstimators.

    It wraps a tf.estimator.Estimator built from the given model_fn, whose `params`
    dict holds bert_config_file, init_checkpoint and use_one_hot_embeddings together
    with every extra keyword argument (for example an optimizer).

    For its subclass:
    - One can add additional arguments and access them via `params`.
    - One can utilize `bert_model` to create model_fn and `bert_input_fn` to create input_fn.
    """

    def __init__(self, model_fn, bert_config_file, init_checkpoint=None,
                 use_one_hot_embeddings=False, model_dir=None, **kwargs):
        """
        :param model_fn: the model function handed to tf.estimator.Estimator.
        :param bert_config_file: stored in params under "bert_config_file".
        :param init_checkpoint: stored in params under "init_checkpoint". Default is None.
        :param use_one_hot_embeddings: stored in params under "use_one_hot_embeddings".
               Default is False.
        :param model_dir: the model_dir handed to tf.estimator.Estimator. Default is None.
        :param kwargs: additional entries merged into params as they are.
        """
        import tensorflow as tf

        params = dict(bert_config_file=bert_config_file,
                      init_checkpoint=init_checkpoint,
                      use_one_hot_embeddings=use_one_hot_embeddings)
        # Subclass specific arguments travel to model_fn through the same params dict.
        params.update(kwargs)
        estimator = tf.estimator.Estimator(model_fn=model_fn,
                                           model_dir=model_dir,
                                           params=params)
        super(BERTBaseEstimator, self).__init__(estimator)