emojis/main.py
Daniel Tsvetkov 5405417c64 get result
2022-05-05 07:45:47 +02:00

120 lines
3.9 KiB
Python

import argparse
import json
import logging
import os
from collections import defaultdict
from fuzzywuzzy import fuzz
from nltk import PorterStemmer
# Directory that contains this module; data files are resolved relative to it.
basepath = os.path.dirname(os.path.abspath(__file__))

# A fuzzy token comparison counts as a match only when its ratio is strictly
# greater than this value.
FUZZ_THRESHOLD = 70
# Default for the ``-c`` command-line option.
DEFAULT_RESULTS_COUNT = 5

# Keys looked up on every emoji record.
FIELD_EMOJI = 'emoji'
FIELD_DESCRIPTION = 'description'
FIELD_CATEGORY = 'category'
FIELD_ALIASES = 'aliases'
# Previously 'aliases' (copy-paste slip): aliases were indexed twice and tags
# were never searched.
# NOTE(review): assumes the records in data/emoji.json carry a 'tags' key
# (gemoji-style schema) - confirm against the data file.
FIELD_TAGS = 'tags'

# Record fields whose values are concatenated into the searchable text.
search_fields = [FIELD_EMOJI, FIELD_DESCRIPTION, FIELD_ALIASES, FIELD_TAGS]

logging.basicConfig()
# Root logger; its level is adjusted later by setup_logging_level().
logger = logging.getLogger()
class HashableDict(dict):
    """A ``dict`` that can serve as a dictionary key or set member.

    Only the value stored under ``FIELD_EMOJI`` feeds the hash; equality is
    still the ordinary ``dict`` comparison of all items.
    """

    def __hash__(self):
        emoji = self.get(FIELD_EMOJI)
        return hash(emoji)
def setup_logging_level(debug=False):
    """Set the module logger to DEBUG when *debug* is true, else to ERROR."""
    if debug:
        level = logging.DEBUG
    else:
        level = logging.ERROR
    logger.setLevel(level)
    # Only emitted when the level chosen above is DEBUG.
    logger.debug("Debugging enabled")
def parse_args(argv=None):
    """Parse the command-line options of the emoji search tool.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        An ``argparse.Namespace`` with the attributes ``query``, ``result``,
        ``case_insensitive``, ``is_stemming``, ``is_fuzzed``, ``results_cnt``
        and ``debug``.
    """
    parser = argparse.ArgumentParser(description="Search emojis by keyword.")
    parser.add_argument('query', nargs='*', default="", help="freeform")
    parser.add_argument(
        '-r', dest='result', type=int, default=0,
        help="print only the emoji of the N-th result (1-based); "
             "0 lists the results instead (default: %(default)s)")
    # -i / -s / -f are "store_false" switches: each feature is ON by default
    # and passing the flag turns it OFF.
    parser.add_argument(
        '-i', dest='case_insensitive', action='store_false',
        help="match case-sensitively (case-insensitive by default)")
    parser.add_argument(
        '-s', dest='is_stemming', action='store_false',
        help="disable stemming of query and emoji text")
    parser.add_argument(
        '-f', dest='is_fuzzed', action='store_false',
        help="disable fuzzy matching; tokens must match exactly")
    parser.add_argument(
        '-c', dest='results_cnt', type=int, default=DEFAULT_RESULTS_COUNT,
        help="maximum number of results to list (default: %(default)s)")
    parser.add_argument(
        '--debug', dest='debug', action='store_true',
        help="enable debug logging")
    return parser.parse_args(argv)
def load_data():
    """Load the emoji records and attach a ``search_text`` string to each.

    Reads ``data/emoji.json`` located next to this module. For every record,
    the values of the fields listed in ``search_fields`` are concatenated,
    each followed by a single space; list values are space-joined first.
    Fields that are absent (or ``null``) are skipped instead of raising.

    Returns:
        The list parsed from the JSON file, with a ``'search_text'`` key added
        to every entry (entries are modified in place).
    """
    data_path = os.path.join(basepath, 'data', 'emoji.json')
    # The data contains non-ASCII emoji characters; decode explicitly rather
    # than relying on the platform's locale encoding.
    # NOTE(review): assumes the file is stored as UTF-8 - confirm.
    with open(data_path, encoding='utf-8') as f:
        emoji_data = json.load(f)

    for entry in emoji_data:
        parts = []
        for field in search_fields:
            field_value = entry.get(field)
            if field_value is None:
                # Missing field: nothing to index (previously a TypeError).
                continue
            if isinstance(field_value, list):
                parts.append(' '.join(field_value))
            else:
                parts.append(field_value)
        # Each part keeps its trailing space, matching the original layout.
        entry['search_text'] = ''.join(part + ' ' for part in parts)
    return emoji_data
def text_to_tokens(text, is_case_insensitive, is_stemming):
    """Split *text* on whitespace into a list of tokens.

    The text is lower-cased first when *is_case_insensitive* is true, and
    every token is passed through a ``PorterStemmer`` when *is_stemming* is
    true.
    """
    normalized = text.lower() if is_case_insensitive else text
    words = normalized.split()
    if not is_stemming:
        return words
    stem = PorterStemmer().stem
    return [stem(word) for word in words]
def search(query, is_case_insensitive=True, is_stemming=True, is_fuzzed=True):
    """Score every emoji record against the query and rank the matches.

    Args:
        query: Iterable of query words (as produced by argparse), or a single
            string of whitespace-separated words.
        is_case_insensitive: Lower-case query and emoji text before comparing.
        is_stemming: Stem query and emoji tokens before comparing.
        is_fuzzed: When true, a token pair matches if ``fuzz.ratio`` exceeds
            ``FUZZ_THRESHOLD`` and contributes ``ratio / 100`` to the score;
            when false, only identical tokens match and contribute ``1``.

    Returns:
        A list of ``(HashableDict, score)`` tuples sorted by descending score.
        Records without any matching token are omitted.
    """
    # A bare string would otherwise be joined character by character.
    query_text = query if isinstance(query, str) else ' '.join(query)
    query_tokens = text_to_tokens(query_text, is_case_insensitive, is_stemming)
    emoji_data = load_data()
    if not query_tokens:
        return []

    # Tokenize (and stem) each record once instead of once per query token,
    # and build its hashable key once instead of once per match.
    documents = [
        (HashableDict(entry),
         text_to_tokens(entry['search_text'], is_case_insensitive, is_stemming))
        for entry in emoji_data
    ]

    results = defaultdict(int)
    # Loop order is kept (query token outermost) so that insertion order, and
    # therefore the order of equally scored results, is unchanged.
    for query_token in query_tokens:
        for key, doc_tokens in documents:
            for doc_token in doc_tokens:
                if is_fuzzed:
                    fuzz_ratio = fuzz.ratio(query_token, doc_token)
                    if fuzz_ratio > FUZZ_THRESHOLD:
                        results[key] += fuzz_ratio / 100
                elif query_token == doc_token:
                    results[key] += 1
    return sorted(results.items(), key=lambda item: item[1], reverse=True)
def format_results(results, results_cnt):
    """Print the first *results_cnt* results as a numbered list.

    Each line shows the 1-based position, the emoji, its description and the
    score rounded to two decimals.
    """
    line_template = "[{}] {} - {} - (score: {})"
    for position, result in enumerate(results[:results_cnt], start=1):
        entry = result[0]
        score = result[1]
        line = line_template.format(
            position,
            entry.get(FIELD_EMOJI),
            entry.get(FIELD_DESCRIPTION),
            round(score, 2),
        )
        print(line)
def main():
    """Run a search from the command line and print the outcome.

    Prints the full result list unless ``-r`` selects a valid 1-based result
    position, in which case only that result's emoji is printed.
    """
    args = parse_args()
    setup_logging_level(args.debug)
    results = search(
        query=args.query,
        is_case_insensitive=args.case_insensitive,
        is_stemming=args.is_stemming,
        is_fuzzed=args.is_fuzzed,
    )

    pick = args.result
    # No selection, or a selection outside the result list: show the listing.
    if not pick or not (0 < pick <= len(results)):
        format_results(results, results_cnt=args.results_cnt)
        return

    chosen_entry = results[pick - 1][0]
    print(chosen_entry.get(FIELD_EMOJI))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()