vprdb.global_localization
The global localization pipeline builds matches between a source database and query frames using different VPR systems.
# Copyright (c) 2023, Ivan Moskalenko, Anastasiia Kornilova
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pipeline for global localization allows to build matches between database and queries using different VPR systems.
"""
from vprdb.global_localization.global_localization import GlobalLocalization

__all__ = ["GlobalLocalization"]
class GlobalLocalization:
class GlobalLocalization:
    """
    Allows to create a predictor based on a given database
    and one of the methods of global localization.
    Results can be improved with SuperGlue.
    """

    def __init__(
        self,
        global_extractor: CosPlace | NetVLAD,
        source_db: Database,
        local_matcher: Optional[SuperGlue] = None,
    ):
        self.__global_extractor = global_extractor
        self.__local_matcher = local_matcher
        self.__source_db = source_db

        print("Calculating of global descriptors for source DB")
        self.source_global_descs = self.__global_extractor.get_database_descriptors(
            self.__source_db
        )
        self.faiss_index = faiss.IndexFlatL2(self.source_global_descs.shape[1])
        self.faiss_index.add(self.source_global_descs)
        if self.__local_matcher is not None:
            print("Calculating of local features for source DB")
            self.source_local_features = self.__local_matcher.get_database_features(
                self.__source_db
            )

    def predict(self, query_database: Database, k_closest: int = 1) -> list[int]:
        """
        Predicts query matches
        :param query_database: The database for which the predictions will be calculated
        :param k_closest: Specifies how many predictions for each query the global localization should make.
        If this value is greater than 1, the best match will be chosen with local matcher
        :return: Indexes of frames from the database, corresponding to the query frames
        """
        if k_closest < 1:
            raise ValueError("K closest value can't be below 1")
        elif k_closest > 1 and self.__local_matcher is None:
            raise ValueError(
                "You can't use K closest value > 1 because you don't have SuperGlue local matcher"
            )

        print("Calculating of global descriptors")
        queries_global_descs = self.__global_extractor.get_database_descriptors(
            query_database
        )
        _, global_predictions = self.faiss_index.search(queries_global_descs, k_closest)

        if k_closest == 1:
            return [prediction[0] for prediction in global_predictions]
        else:
            res_predictions = []
            print("Calculating of local features")
            queries_local_descs = self.__local_matcher.get_database_features(
                query_database
            )
            print("Matching of local features")
            for i, query in enumerate(
                tqdm(queries_local_descs, total=len(query_database))
            ):
                global_query_predictions = global_predictions[i]
                filtered_db_features = self.source_local_features[
                    global_query_predictions
                ]
                local_prediction = self.__local_matcher.match_feature(
                    query, filtered_db_features
                )
                res_predictions.append(global_query_predictions[local_prediction])
            return res_predictions
Creates a predictor from a given source database and one of the global localization methods (CosPlace or NetVLAD). Results can be refined with the SuperGlue local matcher.
GlobalLocalization( global_extractor: vprdb.vpr_systems.cos_place.cos_place.CosPlace | vprdb.vpr_systems.netvlad.netvlad.NetVLAD, source_db: vprdb.core.database.Database, local_matcher: Optional[vprdb.vpr_systems.superglue.superglue.SuperGlue] = None)
    def __init__(
        self,
        global_extractor: CosPlace | NetVLAD,
        source_db: Database,
        local_matcher: Optional[SuperGlue] = None,
    ):
        self.__global_extractor = global_extractor
        self.__local_matcher = local_matcher
        self.__source_db = source_db

        print("Calculating of global descriptors for source DB")
        self.source_global_descs = self.__global_extractor.get_database_descriptors(
            self.__source_db
        )
        self.faiss_index = faiss.IndexFlatL2(self.source_global_descs.shape[1])
        self.faiss_index.add(self.source_global_descs)
        if self.__local_matcher is not None:
            print("Calculating of local features for source DB")
            self.source_local_features = self.__local_matcher.get_database_features(
                self.__source_db
            )
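The constructor stores the source descriptors in a `faiss.IndexFlatL2` index, which performs exact (brute-force) L2 nearest-neighbour search. As a rough, dependency-free illustration of the kind of result such a search yields, the sketch below ranks toy descriptors by Euclidean distance. The function `search_flat_l2` and the 2-D descriptors are hypothetical and are not part of vprdb or faiss.

```python
from math import dist  # Euclidean distance, available since Python 3.8


def search_flat_l2(db_descs, query_descs, k):
    """Return, for each query, the indices of the k closest database descriptors.

    Hypothetical helper: one row per query, k indices per row, closest first.
    """
    if not 1 <= k <= len(db_descs):
        raise ValueError("k must be between 1 and the database size")
    results = []
    for query in query_descs:
        # Rank every database descriptor by its distance to the query
        ranked = sorted(range(len(db_descs)), key=lambda i: dist(query, db_descs[i]))
        results.append(ranked[:k])
    return results


# Toy 2-D "descriptors"; real global descriptors have far more dimensions
database = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
queries = [(0.9, 0.1), (4.0, 4.5)]
top2 = search_flat_l2(database, queries, k=2)
print(top2)
```

Each row of `top2` plays the role of one row of `global_predictions` in the class: a shortlist of source frame indices for a single query.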
    def predict(self, query_database: Database, k_closest: int = 1) -> list[int]:
        """
        Predicts query matches
        :param query_database: The database for which the predictions will be calculated
        :param k_closest: Specifies how many predictions for each query the global localization should make.
        If this value is greater than 1, the best match will be chosen with local matcher
        :return: Indexes of frames from the database, corresponding to the query frames
        """
        if k_closest < 1:
            raise ValueError("K closest value can't be below 1")
        elif k_closest > 1 and self.__local_matcher is None:
            raise ValueError(
                "You can't use K closest value > 1 because you don't have SuperGlue local matcher"
            )

        print("Calculating of global descriptors")
        queries_global_descs = self.__global_extractor.get_database_descriptors(
            query_database
        )
        _, global_predictions = self.faiss_index.search(queries_global_descs, k_closest)

        if k_closest == 1:
            return [prediction[0] for prediction in global_predictions]
        else:
            res_predictions = []
            print("Calculating of local features")
            queries_local_descs = self.__local_matcher.get_database_features(
                query_database
            )
            print("Matching of local features")
            for i, query in enumerate(
                tqdm(queries_local_descs, total=len(query_database))
            ):
                global_query_predictions = global_predictions[i]
                filtered_db_features = self.source_local_features[
                    global_query_predictions
                ]
                local_prediction = self.__local_matcher.match_feature(
                    query, filtered_db_features
                )
                res_predictions.append(global_query_predictions[local_prediction])
            return res_predictions
Predicts a matching source database frame for each query frame.
Parameters
- query_database: The database for which the predictions will be calculated
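- k_closest: The number of candidate matches that global localization retrieves for each query. Must be at least 1. If it is greater than 1, the local matcher chooses the best match among the candidates, so a local matcher must have been passed to the constructor; otherwise a ValueError is raised.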
Returns
Indices of the source database frames that correspond to the query frames, one index per query
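The two-stage behaviour of `predict` described above (a global shortlist of `k_closest` candidates, then an optional local re-ranking) can be sketched without the library. This is a minimal illustration under stated assumptions, not the vprdb implementation: `predict_sketch` and `local_scorer` are hypothetical names, and the scorer is a stand-in for the local matcher that simply returns a number where higher means a better match.

```python
from math import dist


def predict_sketch(db_descs, query_descs, k_closest=1, local_scorer=None):
    """Two-stage matching: global shortlist, then optional local re-ranking.

    local_scorer is a hypothetical callable (query_idx, db_idx) -> score,
    higher meaning a better match; it stands in for the local matcher.
    """
    if k_closest < 1:
        raise ValueError("K closest value can't be below 1")
    if k_closest > 1 and local_scorer is None:
        raise ValueError("k_closest > 1 requires a local matcher")

    predictions = []
    for q_idx, query in enumerate(query_descs):
        # Stage 1: shortlist the k_closest database frames by descriptor distance
        shortlist = sorted(
            range(len(db_descs)), key=lambda i: dist(query, db_descs[i])
        )[:k_closest]
        if k_closest == 1:
            predictions.append(shortlist[0])
        else:
            # Stage 2: keep the shortlisted frame that the local scorer rates highest
            best = max(shortlist, key=lambda db_idx: local_scorer(q_idx, db_idx))
            predictions.append(best)
    return predictions


# Toy data: frame 1 is globally closest, but the scorer prefers frame 0
database = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
queries = [(0.6, 0.0)]
scores = {0: 10, 1: 3, 2: 99}


def scorer(q_idx, db_idx):
    return scores[db_idx]


global_only = predict_sketch(database, queries)
reranked = predict_sketch(database, queries, k_closest=2, local_scorer=scorer)
print(global_only, reranked)
```

In the toy data, frame 2 has the highest local score but never reaches the second stage because it is not in the global shortlist. This shows the trade-off behind `k_closest`: a larger value gives the local matcher more candidates to choose from, at the cost of more local matching work per query.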