initial commit
This commit is contained in:
commit
e296eec1ab
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
.idea
|
||||
venv
|
22753
data/emoji.json
Normal file
22753
data/emoji.json
Normal file
File diff suppressed because it is too large
Load Diff
114
main.py
Normal file
114
main.py
Normal file
@ -0,0 +1,114 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
from fuzzywuzzy import fuzz
|
||||
from nltk import PorterStemmer
|
||||
|
||||
basepath = os.path.dirname(os.path.abspath(__file__))
|
||||
FUZZ_THRESHOLD = 70
|
||||
DEFAULT_RESULTS_COUNT = 5
|
||||
|
||||
FIELD_EMOJI = 'emoji'
|
||||
FIELD_DESCRIPTION = 'description'
|
||||
FIELD_CATEGORY = 'category'
|
||||
FIELD_ALIASES = 'aliases'
|
||||
FIELD_TAGS = 'aliases'
|
||||
|
||||
search_fields = [FIELD_EMOJI, FIELD_DESCRIPTION, FIELD_ALIASES, FIELD_TAGS]
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class HashableDict(dict):
|
||||
def __hash__(self):
|
||||
return hash(self.get(FIELD_EMOJI))
|
||||
|
||||
|
||||
def setup_logging_level(debug=False):
|
||||
log_level = logging.DEBUG if debug else logging.ERROR
|
||||
logger.setLevel(log_level)
|
||||
logger.debug("Debugging enabled")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('query', nargs='*', default="", help="freeform")
|
||||
parser.add_argument('-i', dest='case_insensitive', action='store_false')
|
||||
parser.add_argument('-s', dest='is_stemming', action='store_false')
|
||||
parser.add_argument('-f', dest='is_fuzzed', action='store_false')
|
||||
parser.add_argument('-r', dest='results_cnt', type=int, default=DEFAULT_RESULTS_COUNT)
|
||||
parser.add_argument('--debug', dest='debug', action='store_true')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_data():
|
||||
with open(os.path.join(basepath, 'data', 'emoji.json')) as f:
|
||||
emoji_data = json.load(f)
|
||||
rv = []
|
||||
for entry in emoji_data:
|
||||
search_text = ''
|
||||
for field in search_fields:
|
||||
field_value = entry.get(field)
|
||||
if isinstance(field_value, list):
|
||||
search_text += ' '.join(field_value) + ' '
|
||||
else:
|
||||
search_text += field_value + ' '
|
||||
entry['search_text'] = search_text
|
||||
rv.append(entry)
|
||||
return emoji_data
|
||||
|
||||
|
||||
def text_to_tokens(text, is_case_insensitive, is_stemming):
|
||||
if is_case_insensitive:
|
||||
text = text.lower()
|
||||
tokens = text.split()
|
||||
if is_stemming:
|
||||
stemmer = PorterStemmer()
|
||||
tokens = [stemmer.stem(x) for x in tokens]
|
||||
return tokens
|
||||
|
||||
|
||||
def search(query, is_case_insensitive=True, is_stemming=True, is_fuzzed=True):
|
||||
query = ' '.join(query)
|
||||
query_tokens = text_to_tokens(query, is_case_insensitive, is_stemming)
|
||||
emoji_data = load_data()
|
||||
results = defaultdict(int)
|
||||
for query_token in query_tokens:
|
||||
for entry in emoji_data:
|
||||
doc_tokens = text_to_tokens(entry['search_text'], is_case_insensitive, is_stemming)
|
||||
for doc_token in doc_tokens:
|
||||
if is_fuzzed:
|
||||
fuzz_ratio = fuzz.ratio(query_token, doc_token)
|
||||
if fuzz_ratio > FUZZ_THRESHOLD:
|
||||
results[HashableDict(entry)] += fuzz_ratio / 100
|
||||
else:
|
||||
if query_token == doc_token:
|
||||
results[HashableDict(entry)] += 1
|
||||
return sorted(results.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
def format_results(results, results_cnt):
|
||||
for result in results[:results_cnt]:
|
||||
print("{} - {} - (score: {})".format(result[0].get(FIELD_EMOJI),
|
||||
result[0].get(FIELD_DESCRIPTION),
|
||||
result[1],
|
||||
))
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
setup_logging_level(args.debug)
|
||||
results = search(query=args.query,
|
||||
is_case_insensitive=args.case_insensitive,
|
||||
is_stemming=args.is_stemming,
|
||||
is_fuzzed=args.is_fuzzed,
|
||||
)
|
||||
format_results(results, results_cnt=args.results_cnt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
click==8.0.3
|
||||
fuzzywuzzy==0.18.0
|
||||
joblib==1.1.0
|
||||
nltk==3.7
|
||||
python-Levenshtein==0.12.2
|
||||
regex==2022.1.18
|
||||
tqdm==4.62.3
|
Loading…
Reference in New Issue
Block a user