120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
from collections import defaultdict
|
|
|
|
from fuzzywuzzy import fuzz
|
|
from nltk import PorterStemmer
|
|
|
|
basepath = os.path.dirname(os.path.abspath(__file__))

# A token pair counts as a fuzzy match only when fuzz.ratio exceeds this.
FUZZ_THRESHOLD = 70
# Number of ranked results listed when -c is not given.
DEFAULT_RESULTS_COUNT = 5

# Keys read from each entry of data/emoji.json.
FIELD_EMOJI = 'emoji'
FIELD_DESCRIPTION = 'description'
FIELD_CATEGORY = 'category'
FIELD_ALIASES = 'aliases'
# Was 'aliases' (copy-paste slip), which indexed aliases twice and never
# searched tags. Assumes entries carry a 'tags' list, as in gemoji's
# emoji.json -- confirm against the bundled data file.
FIELD_TAGS = 'tags'

# Fields whose values are concatenated into each entry's searchable text.
search_fields = [FIELD_EMOJI, FIELD_DESCRIPTION, FIELD_ALIASES, FIELD_TAGS]

logging.basicConfig()
# Root logger; its level is adjusted from the --debug flag at startup.
logger = logging.getLogger()
|
|
|
|
|
|
class HashableDict(dict):
    """A dict that can be used as a mapping key, hashed by its emoji value.

    Equality is ordinary dict equality, so two equal instances always hold
    the same FIELD_EMOJI value and therefore hash identically.
    """

    def __hash__(self):
        emoji_char = self.get(FIELD_EMOJI)
        return hash(emoji_char)
|
|
|
|
|
|
def setup_logging_level(debug=False):
    """Set the shared ``logger`` threshold to DEBUG if *debug*, else ERROR.

    Afterwards a debug message is emitted; it is only visible when
    debugging was switched on.
    """
    if debug:
        level = logging.DEBUG
    else:
        level = logging.ERROR
    logger.setLevel(level)
    logger.debug("Debugging enabled")
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        An ``argparse.Namespace`` with ``query`` (list of words), ``result``,
        ``case_insensitive``, ``is_stemming``, ``is_fuzzed``, ``results_cnt``
        and ``debug``.
    """
    parser = argparse.ArgumentParser()
    # nargs='*' yields a list, so the default should be a list as well.
    parser.add_argument('query', nargs='*', default=[], help="freeform")
    parser.add_argument('-r', dest='result', type=int, default=0,
                        help="print only the emoji at this 1-based rank")
    # The next three flags are store_false: the features are ON by default
    # and passing the flag turns them OFF.
    parser.add_argument('-i', dest='case_insensitive', action='store_false',
                        help="disable case-insensitive matching")
    parser.add_argument('-s', dest='is_stemming', action='store_false',
                        help="disable Porter stemming")
    parser.add_argument('-f', dest='is_fuzzed', action='store_false',
                        help="disable fuzzy matching (exact tokens only)")
    parser.add_argument('-c', dest='results_cnt', type=int,
                        default=DEFAULT_RESULTS_COUNT,
                        help="number of results to list (default: %(default)s)")
    parser.add_argument('--debug', dest='debug', action='store_true',
                        help="enable debug logging")
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_data():
    """Load the emoji database and attach a ``search_text`` field to entries.

    Reads ``data/emoji.json`` located under ``basepath``. For every entry,
    the values of ``search_fields`` (each a string or a list of strings) are
    joined with single spaces into ``entry['search_text']``. Fields that are
    missing from an entry are skipped.

    Returns:
        The list of entry dicts parsed from the JSON file, each mutated in
        place to carry ``search_text``.
    """
    data_path = os.path.join(basepath, 'data', 'emoji.json')
    # The file holds emoji characters; don't depend on the platform's
    # default text encoding (e.g. cp1252 on Windows).
    with open(data_path, encoding='utf-8') as f:
        emoji_data = json.load(f)

    for entry in emoji_data:
        parts = []
        for field in search_fields:
            field_value = entry.get(field)
            if field_value is None:
                # Previously crashed with TypeError (None + ' ').
                continue
            if isinstance(field_value, list):
                parts.extend(field_value)
            else:
                parts.append(field_value)
        entry['search_text'] = ' '.join(parts)

    return emoji_data
|
|
|
|
|
|
_STEMMER = None


def _get_stemmer():
    """Return one shared PorterStemmer, creating it on first use."""
    global _STEMMER
    if _STEMMER is None:
        _STEMMER = PorterStemmer()
    return _STEMMER


def text_to_tokens(text, is_case_insensitive, is_stemming):
    """Split *text* into whitespace-separated tokens.

    Args:
        text: The string to tokenize.
        is_case_insensitive: If true, lower-case *text* before splitting.
        is_stemming: If true, reduce each token with NLTK's PorterStemmer.

    Returns:
        A list of token strings (empty for blank input).
    """
    if is_case_insensitive:
        text = text.lower()
    tokens = text.split()
    if is_stemming:
        # search() calls this once per emoji entry, so reuse a single
        # stemmer instead of constructing a new one on every call.
        stem = _get_stemmer().stem
        tokens = [stem(token) for token in tokens]
    return tokens
|
|
|
|
|
|
def search(query, is_case_insensitive=True, is_stemming=True, is_fuzzed=True):
    """Score every emoji entry against *query* and rank the matches.

    Args:
        query: The query words, either as a sequence of strings (as argparse
            produces) or as a single string.
        is_case_insensitive: Lower-case query and entry text before matching.
        is_stemming: Porter-stem query and entry tokens before matching.
        is_fuzzed: If true, every query/entry token pair whose
            ``fuzz.ratio`` exceeds FUZZ_THRESHOLD adds ``ratio / 100`` to the
            entry's score; otherwise each exact token match adds 1.

    Returns:
        A list of ``(entry, score)`` tuples sorted by descending score. Only
        entries with a positive score appear. Each entry is a HashableDict
        copy of the loaded record (including its ``search_text``).
    """
    # A bare string would otherwise be joined character by character.
    query_text = query if isinstance(query, str) else ' '.join(query)
    query_tokens = text_to_tokens(query_text, is_case_insensitive, is_stemming)
    if not query_tokens:
        return []

    emoji_data = load_data()
    # Tokenize each entry once; the previous version re-tokenized (and
    # re-stemmed) every entry for every query token. The scoring loop
    # below keeps the original query-token-major order, so scores and
    # tie ordering are unchanged.
    indexed = [
        (HashableDict(entry),
         text_to_tokens(entry['search_text'], is_case_insensitive, is_stemming))
        for entry in emoji_data
    ]

    results = defaultdict(int)
    for query_token in query_tokens:
        for key, doc_tokens in indexed:
            for doc_token in doc_tokens:
                if is_fuzzed:
                    fuzz_ratio = fuzz.ratio(query_token, doc_token)
                    if fuzz_ratio > FUZZ_THRESHOLD:
                        results[key] += fuzz_ratio / 100
                elif query_token == doc_token:
                    results[key] += 1

    return sorted(results.items(), key=lambda item: item[1], reverse=True)
|
|
|
|
|
|
def format_results(results, results_cnt):
    """Print at most *results_cnt* ranked results, one line each.

    Args:
        results: Sequence of ``(entry, score)`` pairs, best first.
        results_cnt: Maximum number of lines to print.
    """
    shown = results[:results_cnt]
    for rank, item in enumerate(shown, start=1):
        entry, score = item[0], item[1]
        line = "[{}] {} - {} - (score: {})".format(
            rank,
            entry.get(FIELD_EMOJI),
            entry.get(FIELD_DESCRIPTION),
            round(score, 2),
        )
        print(line)
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses the arguments, configures logging, runs search(), and prints
    either the single emoji picked with ``-r N`` (1-based rank) or the
    top ``-c`` results.
    """
    args = parse_args()
    setup_logging_level(args.debug)

    search_kwargs = dict(
        query=args.query,
        is_case_insensitive=args.case_insensitive,
        is_stemming=args.is_stemming,
        is_fuzzed=args.is_fuzzed,
    )
    results = search(**search_kwargs)

    pick = args.result
    # An unset (0) or out-of-range -r falls back to the ranked listing.
    if not (pick and 0 < pick <= len(results)):
        format_results(results, results_cnt=args.results_cnt)
        return
    print(results[pick - 1][0].get(FIELD_EMOJI))


if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
|