Source code for zoo.examples.textclassification.news20

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tarfile
import zipfile
from bigdl.dataset import base
from bigdl.util.common import *

NEWS20_URL = 'http://qwone.com/~jason/20Newsgroups/20news-18828.tar.gz'
GLOVE_URL = 'http://nlp.stanford.edu/data/glove.6B.zip'


def get_news20(base_dir="./data/20news-18828"):
    """
    Load the 20 Newsgroup dataset as raw texts paired with integer labels.
    The dataset will be downloaded automatically if not found in the target
    base_dir.

    :param base_dir: Directory that contains (or will receive) the extracted
                     "20news-18828" folder.
    :return: A tuple (texts, class_num). texts is a list of
             (content, label_id) pairs where content is the full text of one
             file and label_id is the zero-based index of its category
             directory in sorted order. class_num is the number of category
             directories found.
    """
    news20_dir = os.path.join(base_dir, "20news-18828")
    if not os.path.isdir(news20_dir):
        download_news20(base_dir)

    # The built-in open() of Python 2 has no `encoding` argument, so only
    # pass it on Python 3. Decided once here instead of once per file.
    open_kwargs = {} if sys.version_info < (3,) else {"encoding": "latin-1"}

    texts = []
    label_id = 0
    for category in sorted(os.listdir(news20_dir)):
        category_dir = os.path.join(news20_dir, category)
        if not os.path.isdir(category_dir):
            # Stray files next to the category folders are not classes and
            # must not consume a label id.
            continue
        for text_file in sorted(os.listdir(category_dir)):
            # Only entries whose names are purely numeric are read.
            if not text_file.isdigit():
                continue
            text_file_path = os.path.join(category_dir, text_file)
            # `with` guarantees the handle is released even if read() fails.
            with open(text_file_path, **open_kwargs) as f:
                texts.append((f.read(), label_id))
        label_id += 1

    class_num = label_id
    print('Found %s texts.' % len(texts))
    return texts, class_num
def get_glove(base_dir="./data/glove.6B", dim=100):
    """
    Parse the pre-trained glove6B word2vec and return a dict mapping from
    word to vector, given the dim of a vector.
    The word embeddings will be downloaded automatically if not found in the
    target base_dir.

    :param base_dir: Directory that contains (or will receive) the extracted
                     "glove.6B" folder.
    :param dim: Value substituted into the file name "glove.6B.<dim>d.txt"
                to select which embedding file is read.
    :return: A dict mapping the first space-separated field of every line to
             the list of floats parsed from the remaining fields.
    """
    glove_dir = os.path.join(base_dir, "glove.6B")
    if not os.path.isdir(glove_dir):
        download_glove(base_dir)
    glove_path = os.path.join(glove_dir, "glove.6B.%sd.txt" % dim)

    # The built-in open() of Python 2 has no `encoding` argument.
    # NOTE(review): latin-1 never fails to decode, but if the file is really
    # UTF-8 then non-ASCII words end up mis-decoded as keys -- confirm before
    # changing, since existing callers may rely on the current keys.
    open_kwargs = {} if sys.version_info < (3,) else {"encoding": "latin-1"}

    pre_w2v = {}
    with open(glove_path, **open_kwargs) as w2v_f:
        # Iterate lazily: the embedding files are large and readlines()
        # would hold a second full copy of the file in memory.
        for line in w2v_f:
            items = line.split(" ")
            # float() tolerates the trailing newline on the last field.
            pre_w2v[items[0]] = [float(i) for i in items[1:]]
    return pre_w2v
def download_news20(dest_dir):
    """
    Fetch the 20 Newsgroup archive into dest_dir and extract it there unless
    the "20news-18828" folder already exists.

    :param dest_dir: Directory the archive is downloaded to and extracted in.
    """
    news20 = "20news-18828.tar.gz"
    # The return value is used as the archive path below; presumably
    # maybe_download skips the download when the file is already present --
    # verify against bigdl.dataset.base.
    news20_path = base.maybe_download(news20, dest_dir, NEWS20_URL)
    news20_dir = os.path.join(dest_dir, "20news-18828")
    if os.path.exists(news20_dir):
        # Already extracted: do not even open the archive.
        return
    print("Extracting %s to %s" % (news20_path, news20_dir))
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; acceptable only as long as NEWS20_URL is a trusted source.
    # `with` closes the archive even if extraction fails part-way.
    with tarfile.open(news20_path, "r:gz") as tar:
        tar.extractall(dest_dir)
def download_glove(dest_dir):
    """
    Fetch the GloVe 6B archive into dest_dir and extract it into the
    "glove.6B" sub-folder unless that folder already exists.

    :param dest_dir: Directory the archive is downloaded to; the contents are
                     extracted into dest_dir/glove.6B.
    """
    glove = "glove.6B.zip"
    # The return value is used as the archive path below; presumably
    # maybe_download skips the download when the file is already present --
    # verify against bigdl.dataset.base.
    glove_path = base.maybe_download(glove, dest_dir, GLOVE_URL)
    glove_dir = os.path.join(dest_dir, "glove.6B")
    if os.path.exists(glove_dir):
        # Already extracted: do not even open the archive.
        return
    print("Extracting %s to %s" % (glove_path, glove_dir))
    # Unlike the news20 tarball, this archive is extracted into glove_dir
    # itself rather than into dest_dir.
    # `with` closes the archive even if extraction fails part-way.
    with zipfile.ZipFile(glove_path, 'r') as zip_ref:
        zip_ref.extractall(glove_dir)