vprdb.vpr_systems
The vpr_systems
contains a set of tools for the VPR task.
1# Copyright (c) 2023, Ivan Moskalenko, Anastasiia Kornilova 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14""" The `vpr_systems` contains a set of tools for the VPR task. """ 15from vprdb.vpr_systems.cos_place import CosPlace 16from vprdb.vpr_systems.netvlad import NetVLAD 17from vprdb.vpr_systems.superglue import SuperGlue 18 19__all__ = ["CosPlace", "NetVLAD", "SuperGlue"]
class
CosPlace:
39class CosPlace: 40 """ 41 Implementation of [CosPlace](https://github.com/gmberton/CosPlace) global localization method. 42 """ 43 44 def __init__(self, backbone: str, fc_output_dim: int, path_to_weights: str): 45 self.backbone = backbone 46 self.fc_output_dim = fc_output_dim 47 self.path_to_weights = path_to_weights 48 49 self.model = GeoLocalizationNet(backbone, fc_output_dim) 50 model_state_dict = torch.load(self.path_to_weights) 51 self.model.load_state_dict(model_state_dict) 52 53 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 54 self.model.to(self.device) 55 56 def get_database_descriptors(self, database: Database): 57 """ 58 Gets database RGB images CosPlace descriptors 59 :param database: Database for getting descriptors 60 :return: Descriptors for database images 61 """ 62 self.model.eval() 63 image_providers = database.color_images 64 with torch.no_grad(): 65 all_descriptors = np.empty( 66 (len(image_providers), self.fc_output_dim), dtype="float32" 67 ) 68 for i, image in tqdm( 69 enumerate(image_providers), total=len(image_providers) 70 ): 71 image_bgr = image.color_image 72 image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) 73 base_transform = torchvision.transforms.Compose( 74 [ 75 torchvision.transforms.ToTensor(), 76 torchvision.transforms.Normalize( 77 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 78 ), 79 ] 80 ) 81 normalized_img = base_transform(image_rgb) 82 normalized_img = normalized_img[None, :] 83 descriptor = self.model(normalized_img.to(self.device)) 84 descriptor = descriptor.cpu().numpy() 85 all_descriptors[i] = descriptor 86 return all_descriptors 87 88 def fine_tune_model( 89 self, 90 target_db: Database, 91 valid_db: Database, 92 train_db: Database, 93 save_dir: str, 94 voxel_size=0.3, 95 random_resize=(480, 640), 96 brightness=0.7, 97 contrast=0.7, 98 saturation=0.7, 99 hue=0.5, 100 random_resized_crop=0.5, 101 num_workers=0, 102 batch_size=8, 103 lr=0.00001, 104 classifiers_lr=0.01, 105 
patience=10, 106 max_epochs=-1, 107 seed=0, 108 ) -> str: 109 """ 110 Fine-tunes the CosPlace model for given target database 111 :param target_db: The database for which the model will be fine-tuned 112 :param valid_db: Validation database 113 :param train_db: Training database 114 :param save_dir: Directory for saving output model and log 115 :param voxel_size: Voxel size for down sampling point clouds 116 :return: Path to output model 117 """ 118 min_bounds, max_bounds = find_bounds_for_multiple_databases( 119 [target_db, valid_db, train_db] 120 ) 121 voxel_grid = VoxelGrid(min_bounds, max_bounds, voxel_size) 122 123 groups = create_groups(train_db, target_db, voxel_grid) 124 groups_lens = [len(group_db) for group_db, _ in groups] 125 val_matches = match_two_databases(valid_db, target_db, voxel_grid) 126 127 make_deterministic(seed=seed) 128 129 data = DataModule( 130 groups, 131 target_db, 132 valid_db, 133 val_matches, 134 random_resize, 135 brightness, 136 contrast, 137 saturation, 138 hue, 139 random_resized_crop, 140 num_workers, 141 batch_size, 142 ) 143 train_model = CosPlaceTrainer( 144 self.model, 145 self.fc_output_dim, 146 groups_lens, 147 len(target_db), 148 len(valid_db), 149 lr, 150 classifiers_lr, 151 self.device, 152 ) 153 154 checkpoint_callback = ModelCheckpoint( 155 dirpath=save_dir, filename="best_model", save_weights_only=True 156 ) 157 # start training 158 trainer = Trainer( 159 accelerator="auto", 160 devices=[0], 161 max_epochs=max_epochs, 162 callbacks=[ 163 EarlyStopping(monitor="R_1", mode="max", patience=patience), 164 checkpoint_callback, 165 ], 166 default_root_dir=save_dir, 167 ) 168 169 trainer.fit( 170 train_model, 171 datamodule=data, 172 ) 173 174 # Transform weights to PyTorch format 175 model_state_dict = torch.load(checkpoint_callback.best_model_path) 176 model_state_dict = model_state_dict["state_dict"] 177 new_model_state_dict = dict() 178 for k in model_state_dict.keys(): 179 new_model_state_dict[k[6:]] = 
model_state_dict[k] 180 torch.save(new_model_state_dict, checkpoint_callback.best_model_path) 181 return checkpoint_callback.best_model_path
Implementation of CosPlace global localization method.
CosPlace(backbone: str, fc_output_dim: int, path_to_weights: str)
def __init__(self, backbone: str, fc_output_dim: int, path_to_weights: str):
    """
    Builds the CosPlace network and loads pre-trained weights into it
    :param backbone: Backbone identifier forwarded to `GeoLocalizationNet`
    :param fc_output_dim: Output dimension forwarded to `GeoLocalizationNet`
    :param path_to_weights: Path to a state dict file readable by `torch.load`
    """
    self.backbone = backbone
    self.fc_output_dim = fc_output_dim
    self.path_to_weights = path_to_weights

    # The device is chosen first so that the weights can be mapped onto it
    # while they are being loaded.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    self.model = GeoLocalizationNet(backbone, fc_output_dim)
    # Without `map_location`, a checkpoint that was saved from a GPU
    # cannot be deserialized on a CPU-only machine.
    model_state_dict = torch.load(self.path_to_weights, map_location=self.device)
    self.model.load_state_dict(model_state_dict)
    self.model.to(self.device)
def get_database_descriptors(self, database: Database):
    """
    Gets database RGB images CosPlace descriptors
    :param database: Database for getting descriptors
    :return: Descriptors for database images, a float32 array with one row
        per color image and `fc_output_dim` columns
    """
    # The preprocessing pipeline does not depend on the image,
    # so it is built once instead of once per iteration.
    # NOTE(review): the mean/std values are presumably the ImageNet
    # channel statistics -- confirm against the training setup.
    base_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    self.model.eval()
    image_providers = database.color_images
    all_descriptors = np.empty(
        (len(image_providers), self.fc_output_dim), dtype="float32"
    )
    with torch.no_grad():
        for i, image in tqdm(
            enumerate(image_providers), total=len(image_providers)
        ):
            image_bgr = image.color_image
            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            # `[None, :]` adds a leading batch axis of size one
            normalized_img = base_transform(image_rgb)[None, :]
            descriptor = self.model(normalized_img.to(self.device))
            all_descriptors[i] = descriptor.cpu().numpy()
    return all_descriptors
Gets database RGB images CosPlace descriptors
Parameters
- database: Database for getting descriptors
Returns
Descriptors for database images
def
fine_tune_model( self, target_db: vprdb.core.database.Database, valid_db: vprdb.core.database.Database, train_db: vprdb.core.database.Database, save_dir: str, voxel_size=0.3, random_resize=(480, 640), brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5, random_resized_crop=0.5, num_workers=0, batch_size=8, lr=1e-05, classifiers_lr=0.01, patience=10, max_epochs=-1, seed=0) -> str:
def fine_tune_model(
    self,
    target_db: Database,
    valid_db: Database,
    train_db: Database,
    save_dir: str,
    voxel_size=0.3,
    random_resize=(480, 640),
    brightness=0.7,
    contrast=0.7,
    saturation=0.7,
    hue=0.5,
    random_resized_crop=0.5,
    num_workers=0,
    batch_size=8,
    lr=0.00001,
    classifiers_lr=0.01,
    patience=10,
    max_epochs=-1,
    seed=0,
) -> str:
    """
    Fine-tunes the CosPlace model for given target database
    :param target_db: The database for which the model will be fine-tuned
    :param valid_db: Validation database
    :param train_db: Training database
    :param save_dir: Directory for saving output model and log
    :param voxel_size: Voxel size for down sampling point clouds
    :param random_resize: Forwarded unchanged to `DataModule`
    :param brightness: Forwarded unchanged to `DataModule`
    :param contrast: Forwarded unchanged to `DataModule`
    :param saturation: Forwarded unchanged to `DataModule`
    :param hue: Forwarded unchanged to `DataModule`
    :param random_resized_crop: Forwarded unchanged to `DataModule`
    :param num_workers: Forwarded unchanged to `DataModule`
    :param batch_size: Forwarded unchanged to `DataModule`
    :param lr: Forwarded unchanged to `CosPlaceTrainer`
    :param classifiers_lr: Forwarded unchanged to `CosPlaceTrainer`
    :param patience: Patience of the early stopping on the "R_1" metric
    :param max_epochs: Forwarded to `Trainer` as `max_epochs`
    :param seed: Seed passed to `make_deterministic`
    :return: Path to output model
    """
    min_bounds, max_bounds = find_bounds_for_multiple_databases(
        [target_db, valid_db, train_db]
    )
    voxel_grid = VoxelGrid(min_bounds, max_bounds, voxel_size)

    groups = create_groups(train_db, target_db, voxel_grid)
    groups_lens = [len(group_db) for group_db, _ in groups]
    val_matches = match_two_databases(valid_db, target_db, voxel_grid)

    make_deterministic(seed=seed)

    data = DataModule(
        groups,
        target_db,
        valid_db,
        val_matches,
        random_resize,
        brightness,
        contrast,
        saturation,
        hue,
        random_resized_crop,
        num_workers,
        batch_size,
    )
    train_model = CosPlaceTrainer(
        self.model,
        self.fc_output_dim,
        groups_lens,
        len(target_db),
        len(valid_db),
        lr,
        classifiers_lr,
        self.device,
    )

    # The checkpoint callback tracks the same metric as the early stopping.
    # Without `monitor` it would keep the checkpoint of the LAST epoch,
    # i.e. one taken after the metric had already stopped improving.
    checkpoint_callback = ModelCheckpoint(
        dirpath=save_dir,
        filename="best_model",
        save_weights_only=True,
        monitor="R_1",
        mode="max",
    )
    # start training
    trainer = Trainer(
        accelerator="auto",
        devices=[0],
        max_epochs=max_epochs,
        callbacks=[
            EarlyStopping(monitor="R_1", mode="max", patience=patience),
            checkpoint_callback,
        ],
        default_root_dir=save_dir,
    )

    trainer.fit(
        train_model,
        datamodule=data,
    )

    # Transform weights to PyTorch format: the checkpoint file is rewritten
    # in place as a bare state dict.
    best_model_path = checkpoint_callback.best_model_path
    model_state_dict = torch.load(best_model_path)["state_dict"]
    # NOTE(review): the "model." prefix presumably comes from the trainer
    # holding the network in an attribute named `model` -- confirm.
    # The prefix is stripped only where it is present, so that any other
    # key is kept intact instead of losing its first six characters.
    prefix = "model."
    new_model_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in model_state_dict.items()
    }
    torch.save(new_model_state_dict, best_model_path)
    return best_model_path
Fine-tunes the CosPlace model for given target database
Parameters
- target_db: The database for which the model will be fine-tuned
- valid_db: Validation database
- train_db: Training database
- save_dir: Directory for saving output model and log
- voxel_size: Voxel size for down sampling point clouds
Returns
Path to output model
class
NetVLAD:
49class NetVLAD: 50 """ 51 Implementation of [NetVLAD](https://github.com/QVPR/Patch-NetVLAD) global localization method. 52 """ 53 54 def __init__( 55 self, 56 path_to_weights: str, 57 resize: tuple[int, int] = (480, 640), 58 threads: int = 0, 59 batch_size: int = 20, 60 use_vladv2: bool = False, 61 ): 62 self.cuda = torch.cuda.is_available() 63 self.device = torch.device("cuda" if self.cuda else "cpu") 64 self.encoder_dim, self.encoder = get_backend() 65 66 self.resize = resize 67 self.threads = threads 68 self.batch_size = batch_size 69 self.use_vladv2 = use_vladv2 70 71 if isfile(path_to_weights): 72 self.path_to_weights = path_to_weights 73 else: 74 raise FileNotFoundError( 75 "=> no checkpoint found at '{}'".format(path_to_weights) 76 ) 77 78 self.checkpoint = torch.load( 79 self.path_to_weights, map_location=lambda storage, loc: storage 80 ) 81 self.num_clusters = self.checkpoint["state_dict"]["pool.centroids"].shape[0] 82 83 def get_database_descriptors( 84 self, 85 database: Database, 86 ): 87 """ 88 Gets database RGB images CosPlace descriptors 89 :param database: Database for getting descriptors 90 :return: Descriptors for database images 91 """ 92 num_pcs = self.checkpoint["state_dict"]["WPCA.0.bias"].shape[0] 93 94 model = get_model( 95 self.encoder, 96 self.encoder_dim, 97 self.num_clusters, 98 self.use_vladv2, 99 append_pca_layer=True, 100 num_pcs=num_pcs, 101 ) 102 model.load_state_dict(self.checkpoint["state_dict"]) 103 model = model.to(self.device) 104 105 color_images_paths = [img.path for img in database.color_images] 106 dataset = IDataset(color_images_paths, self.resize) 107 test_data_loader = DataLoader( 108 dataset=dataset, 109 num_workers=self.threads, 110 batch_size=self.batch_size, 111 shuffle=False, 112 pin_memory=self.cuda, 113 ) 114 model.eval() 115 with torch.no_grad(): 116 db_feat = np.empty((len(dataset), num_pcs), dtype=np.float32) 117 for iteration, (input_data, indices) in tqdm( 118 enumerate(test_data_loader), 
total=len(test_data_loader) 119 ): 120 indices_np = indices.detach().numpy() 121 input_data = input_data.to(self.device) 122 image_encoding = model.encoder(input_data) 123 vlad_global = model.pool(image_encoding) 124 vlad_global_pca = get_pca_encoding(model, vlad_global) 125 db_feat[indices_np, :] = vlad_global_pca.detach().cpu().numpy() 126 127 return db_feat 128 129 def fine_tune_model( 130 self, 131 target_db: Database, 132 valid_db: Database, 133 train_db: Database, 134 save_dir: str, 135 voxel_size: float = 0.3, 136 seed=42, 137 add_pca=True, 138 optim_name="SGD", 139 lr=0.0001, 140 momentum=0.9, 141 weight_decay=0.001, 142 lr_step=5, 143 lr_gamma=0.5, 144 margin=0.1, 145 nNeg=5, 146 cache_bs=20, 147 bs=4, 148 max_epochs=100, 149 eval_every=1, 150 patience=5, 151 n_features=10000, 152 num_pcs=4096, 153 ) -> str: 154 """ 155 Fine-tunes the NetVLAD model for given target database 156 :param target_db: The database for which the model will be fine-tuned 157 :param valid_db: Validation database 158 :param train_db: Training database 159 :param save_dir: Directory for saving output model and log 160 :param voxel_size: Voxel size for down sampling point clouds 161 :return: Path to output model 162 """ 163 min_bounds, max_bounds = find_bounds_for_multiple_databases( 164 [target_db, valid_db, train_db] 165 ) 166 167 voxel_grid = VoxelGrid(min_bounds, max_bounds, voxel_size) 168 train_targets = match_two_databases(train_db, target_db, voxel_grid) 169 val_targets = match_two_databases(valid_db, target_db, voxel_grid) 170 171 train_paths = [img.path for img in train_db.color_images] 172 val_paths = [img.path for img in valid_db.color_images] 173 db_paths = [img.path for img in target_db.color_images] 174 175 make_deterministic(seed=seed) 176 177 scheduler = None 178 179 checkpoint = torch.load( 180 self.path_to_weights, map_location=lambda storage, loc: storage 181 ) 182 # Deleting WPCA layer 183 del checkpoint["state_dict"]["WPCA.0.weight"] 184 del 
checkpoint["state_dict"]["WPCA.0.bias"] 185 186 model = get_model( 187 self.encoder, self.encoder_dim, self.num_clusters, self.use_vladv2 188 ) 189 model.load_state_dict(checkpoint["state_dict"]) 190 191 if optim_name == "ADAM": 192 optimizer = optim.Adam( 193 filter(lambda par: par.requires_grad, model.parameters()), lr=lr 194 ) 195 elif optim_name == "SGD": 196 optimizer = optim.SGD( 197 filter(lambda par: par.requires_grad, model.parameters()), 198 lr=lr, 199 momentum=momentum, 200 weight_decay=weight_decay, 201 ) 202 203 scheduler = optim.lr_scheduler.StepLR( 204 optimizer, step_size=lr_step, gamma=lr_gamma 205 ) 206 else: 207 raise ValueError("Unknown optimizer: " + optim_name) 208 209 criterion = nn.TripletMarginLoss( 210 margin=(margin**0.5), p=2, reduction="sum" 211 ).to(self.device) 212 213 model = model.to(self.device) 214 215 print("===> Loading dataset(s)") 216 train_dataset = TDataset( 217 db_paths, 218 train_paths, 219 train_targets, 220 n_neg=nNeg, 221 transform=input_transform(), 222 bs=cache_bs, 223 threads=self.threads, 224 ) 225 226 validation_dataset = TDataset( 227 db_paths, 228 val_paths, 229 val_targets, 230 n_neg=nNeg, 231 transform=input_transform(), 232 bs=cache_bs, 233 threads=self.threads, 234 ) 235 236 print("===> Training query set:", len(train_dataset.q_idx)) 237 print("===> Evaluating on val set, query count:", len(validation_dataset.q_idx)) 238 print("===> Training model") 239 240 writer = SummaryWriter(log_dir=save_dir) 241 242 logdir = writer.file_writer.get_logdir() 243 save_file_path = join(logdir, "checkpoints") 244 makedirs(save_file_path) 245 246 not_improved = 0 247 best_score = 0 248 for epoch in trange( 249 1, max_epochs + 1, desc="Epoch number".rjust(15), position=0 250 ): 251 train_epoch( 252 train_dataset, 253 model, 254 optimizer, 255 criterion, 256 self.encoder_dim, 257 self.device, 258 epoch, 259 writer, 260 bs, 261 self.num_clusters, 262 self.threads, 263 ) 264 if scheduler is not None: 265 scheduler.step(epoch) 266 
if (epoch % eval_every) == 0: 267 recalls = validate( 268 validation_dataset, 269 model, 270 self.encoder_dim, 271 self.device, 272 writer, 273 self.threads, 274 cache_bs, 275 self.num_clusters, 276 epoch, 277 write_tboard=True, 278 pbar_position=1, 279 ) 280 is_best = recalls[1] > best_score 281 if is_best: 282 not_improved = 0 283 best_score = recalls[1] 284 else: 285 not_improved += 1 286 287 save_checkpoint( 288 { 289 "epoch": epoch, 290 "state_dict": model.state_dict(), 291 "recalls": recalls, 292 "best_score": best_score, 293 "not_improved": not_improved, 294 "optimizer": optimizer.state_dict(), 295 "parallel": False, 296 }, 297 is_best, 298 save_file_path, 299 ) 300 301 if patience > 0 and not_improved > (patience / int(eval_every)): 302 print( 303 "Performance did not improve for", patience, "epochs. Stopping." 304 ) 305 break 306 307 print("=> Best Recall@5: {:.4f}".format(best_score), flush=True) 308 writer.close() 309 save_path = join(save_file_path, "model_best.pth.tar") 310 print("Done") 311 312 if add_pca: 313 print("Adding PCA layer") 314 model = get_model( 315 self.encoder, 316 self.encoder_dim, 317 self.num_clusters, 318 append_pca_layer=False, 319 ) 320 model.load_state_dict(checkpoint["state_dict"]) 321 model = model.to(self.device) 322 323 pool_size = self.encoder_dim 324 pool_size *= self.num_clusters 325 326 print("===> Loading PCA dataset(s)") 327 328 if n_features > len(target_db): 329 n_features = len(target_db) 330 331 sampler = SubsetRandomSampler( 332 np.random.choice(len(target_db), n_features, replace=False) 333 ) 334 335 data_loader = DataLoader( 336 dataset=IDataset(db_paths), 337 num_workers=self.threads, 338 batch_size=cache_bs, 339 shuffle=False, 340 pin_memory=self.cuda, 341 sampler=sampler, 342 ) 343 344 print("===> Do inference to extract features and save them.") 345 346 model.eval() 347 with torch.no_grad(): 348 tqdm.write("====> Extracting Features") 349 350 db_feat = np.empty((len(data_loader.sampler), pool_size)) 351 
print("Compute", len(db_feat), "features") 352 353 for iteration, (input_data, indices) in tqdm( 354 enumerate(data_loader), total=len(data_loader) 355 ): 356 input_data = input_data.to(self.device) 357 image_encoding = model.encoder(input_data) 358 vlad_encoding = model.pool(image_encoding) 359 out_vectors = vlad_encoding.detach().cpu().numpy() 360 # this allows for randomly shuffled inputs 361 for idx, out_vector in enumerate(out_vectors): 362 db_feat[ 363 iteration * data_loader.batch_size + idx, : 364 ] = out_vector 365 366 del input_data, image_encoding, vlad_encoding 367 368 print("===> Compute PCA, takes a while") 369 model_pca = pca(model, num_pcs, db_feat, pool_size) 370 371 save_path = save_path.replace(".pth.tar", "_WPCA.pth.tar") 372 373 torch.save({"state_dict": model_pca.state_dict()}, save_path) 374 375 print("Done") 376 377 return save_path
Implementation of NetVLAD global localization method.
NetVLAD( path_to_weights: str, resize: tuple[int, int] = (480, 640), threads: int = 0, batch_size: int = 20, use_vladv2: bool = False)
def __init__(
    self,
    path_to_weights: str,
    resize: tuple[int, int] = (480, 640),
    threads: int = 0,
    batch_size: int = 20,
    use_vladv2: bool = False,
):
    """
    Stores the settings and reads the NetVLAD checkpoint
    :param path_to_weights: Path to the checkpoint file
    :param resize: Size passed to the image dataset when descriptors are extracted
    :param threads: Number of data loader workers
    :param batch_size: Batch size used when descriptors are extracted
    :param use_vladv2: Flag forwarded to `get_model`
    :raises FileNotFoundError: If `path_to_weights` is not an existing file
    """
    has_cuda = torch.cuda.is_available()
    self.cuda = has_cuda
    self.device = torch.device("cuda" if has_cuda else "cpu")
    self.encoder_dim, self.encoder = get_backend()

    self.resize, self.threads = resize, threads
    self.batch_size, self.use_vladv2 = batch_size, use_vladv2

    if not isfile(path_to_weights):
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(path_to_weights)
        )
    self.path_to_weights = path_to_weights

    # The identity `map_location` keeps every tensor in CPU memory
    self.checkpoint = torch.load(
        self.path_to_weights, map_location=lambda storage, loc: storage
    )
    centroids = self.checkpoint["state_dict"]["pool.centroids"]
    # The number of clusters is read from the first axis of the centroids
    self.num_clusters = centroids.shape[0]
def get_database_descriptors(
    self,
    database: Database,
):
    """
    Gets database RGB images NetVLAD descriptors
    :param database: Database for getting descriptors
    :return: Descriptors for database images, a float32 array with one row
        per color image and one column per PCA component
    :raises KeyError: If the checkpoint has no "WPCA.0.bias" entry
    """
    state_dict = self.checkpoint["state_dict"]
    # The descriptor length is the output size of the stored PCA layer
    num_pcs = state_dict["WPCA.0.bias"].shape[0]

    model = get_model(
        self.encoder,
        self.encoder_dim,
        self.num_clusters,
        self.use_vladv2,
        append_pca_layer=True,
        num_pcs=num_pcs,
    )
    model.load_state_dict(state_dict)
    model = model.to(self.device)

    color_images_paths = [img.path for img in database.color_images]
    dataset = IDataset(color_images_paths, self.resize)
    test_data_loader = DataLoader(
        dataset=dataset,
        num_workers=self.threads,
        batch_size=self.batch_size,
        shuffle=False,
        pin_memory=self.cuda,
    )

    model.eval()
    db_feat = np.empty((len(dataset), num_pcs), dtype=np.float32)
    with torch.no_grad():
        for input_data, indices in tqdm(
            test_data_loader, total=len(test_data_loader)
        ):
            # Rows are addressed by the indices yielded with each batch
            indices_np = indices.detach().numpy()
            input_data = input_data.to(self.device)
            image_encoding = model.encoder(input_data)
            vlad_global = model.pool(image_encoding)
            vlad_global_pca = get_pca_encoding(model, vlad_global)
            db_feat[indices_np, :] = vlad_global_pca.detach().cpu().numpy()

    return db_feat
Gets database RGB images NetVLAD descriptors
Parameters
- database: Database for getting descriptors
Returns
Descriptors for database images
def
fine_tune_model( self, target_db: vprdb.core.database.Database, valid_db: vprdb.core.database.Database, train_db: vprdb.core.database.Database, save_dir: str, voxel_size: float = 0.3, seed=42, add_pca=True, optim_name='SGD', lr=0.0001, momentum=0.9, weight_decay=0.001, lr_step=5, lr_gamma=0.5, margin=0.1, nNeg=5, cache_bs=20, bs=4, max_epochs=100, eval_every=1, patience=5, n_features=10000, num_pcs=4096) -> str:
def fine_tune_model(
    self,
    target_db: Database,
    valid_db: Database,
    train_db: Database,
    save_dir: str,
    voxel_size: float = 0.3,
    seed=42,
    add_pca=True,
    optim_name="SGD",
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.001,
    lr_step=5,
    lr_gamma=0.5,
    margin=0.1,
    nNeg=5,
    cache_bs=20,
    bs=4,
    max_epochs=100,
    eval_every=1,
    patience=5,
    n_features=10000,
    num_pcs=4096,
) -> str:
    """
    Fine-tunes the NetVLAD model for given target database
    :param target_db: The database for which the model will be fine-tuned
    :param valid_db: Validation database
    :param train_db: Training database
    :param save_dir: Directory for saving output model and log
    :param voxel_size: Voxel size for down sampling point clouds
    :param seed: Seed passed to `make_deterministic`
    :param add_pca: Whether a PCA layer is computed on target database
        features and appended to the fine-tuned model
    :param optim_name: Optimizer to use, either "ADAM" or "SGD"
    :param lr: Learning rate of the optimizer
    :param momentum: Momentum, used by "SGD" only
    :param weight_decay: Weight decay, used by "SGD" only
    :param lr_step: Step size of the learning rate scheduler, "SGD" only
    :param lr_gamma: Decay factor of the learning rate scheduler, "SGD" only
    :param margin: Triplet loss margin; its square root is passed to the loss
    :param nNeg: Passed to the training and validation datasets as `n_neg`
    :param cache_bs: Batch size passed to the datasets and to validation,
        also used for the PCA feature extraction
    :param bs: Batch size passed to `train_epoch`
    :param max_epochs: Maximum number of training epochs
    :param eval_every: Validation is run every `eval_every` epochs
    :param patience: Training stops once the score has not improved for more
        than `patience / eval_every` validations; values <= 0 disable this
    :param n_features: Maximum number of target images sampled for PCA
    :param num_pcs: Passed to `pca` as the number of principal components
    :return: Path to output model
    :raises ValueError: If `optim_name` is neither "ADAM" nor "SGD"
    """
    min_bounds, max_bounds = find_bounds_for_multiple_databases(
        [target_db, valid_db, train_db]
    )

    voxel_grid = VoxelGrid(min_bounds, max_bounds, voxel_size)
    train_targets = match_two_databases(train_db, target_db, voxel_grid)
    val_targets = match_two_databases(valid_db, target_db, voxel_grid)

    train_paths = [img.path for img in train_db.color_images]
    val_paths = [img.path for img in valid_db.color_images]
    db_paths = [img.path for img in target_db.color_images]

    make_deterministic(seed=seed)

    scheduler = None

    checkpoint = torch.load(
        self.path_to_weights, map_location=lambda storage, loc: storage
    )
    # Deleting WPCA layer (a checkpoint without one is accepted as well)
    checkpoint["state_dict"].pop("WPCA.0.weight", None)
    checkpoint["state_dict"].pop("WPCA.0.bias", None)

    model = get_model(
        self.encoder, self.encoder_dim, self.num_clusters, self.use_vladv2
    )
    model.load_state_dict(checkpoint["state_dict"])

    if optim_name == "ADAM":
        optimizer = optim.Adam(
            filter(lambda par: par.requires_grad, model.parameters()), lr=lr
        )
    elif optim_name == "SGD":
        optimizer = optim.SGD(
            filter(lambda par: par.requires_grad, model.parameters()),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=lr_step, gamma=lr_gamma
        )
    else:
        raise ValueError("Unknown optimizer: " + optim_name)

    criterion = nn.TripletMarginLoss(
        margin=(margin**0.5), p=2, reduction="sum"
    ).to(self.device)

    model = model.to(self.device)

    print("===> Loading dataset(s)")
    train_dataset = TDataset(
        db_paths,
        train_paths,
        train_targets,
        n_neg=nNeg,
        transform=input_transform(),
        bs=cache_bs,
        threads=self.threads,
    )

    validation_dataset = TDataset(
        db_paths,
        val_paths,
        val_targets,
        n_neg=nNeg,
        transform=input_transform(),
        bs=cache_bs,
        threads=self.threads,
    )

    print("===> Training query set:", len(train_dataset.q_idx))
    print("===> Evaluating on val set, query count:", len(validation_dataset.q_idx))
    print("===> Training model")

    writer = SummaryWriter(log_dir=save_dir)

    logdir = writer.file_writer.get_logdir()
    save_file_path = join(logdir, "checkpoints")
    # `exist_ok` lets the same `save_dir` be used for more than one run
    makedirs(save_file_path, exist_ok=True)

    not_improved = 0
    best_score = 0
    # Set once this run has handed a best checkpoint to `save_checkpoint`
    best_saved = False
    try:
        for epoch in trange(
            1, max_epochs + 1, desc="Epoch number".rjust(15), position=0
        ):
            train_epoch(
                train_dataset,
                model,
                optimizer,
                criterion,
                self.encoder_dim,
                self.device,
                epoch,
                writer,
                bs,
                self.num_clusters,
                self.threads,
            )
            if scheduler is not None:
                # Called once per epoch; the explicit `epoch` argument
                # of `step` is deprecated in PyTorch.
                scheduler.step()
            if (epoch % eval_every) == 0:
                recalls = validate(
                    validation_dataset,
                    model,
                    self.encoder_dim,
                    self.device,
                    writer,
                    self.threads,
                    cache_bs,
                    self.num_clusters,
                    epoch,
                    write_tboard=True,
                    pbar_position=1,
                )
                is_best = recalls[1] > best_score
                if is_best:
                    not_improved = 0
                    best_score = recalls[1]
                    best_saved = True
                else:
                    not_improved += 1

                save_checkpoint(
                    {
                        "epoch": epoch,
                        "state_dict": model.state_dict(),
                        "recalls": recalls,
                        "best_score": best_score,
                        "not_improved": not_improved,
                        "optimizer": optimizer.state_dict(),
                        "parallel": False,
                    },
                    is_best,
                    save_file_path,
                )

                if patience > 0 and not_improved > (patience / int(eval_every)):
                    print(
                        "Performance did not improve for", patience, "epochs. Stopping."
                    )
                    break
    finally:
        # The writer is closed even when training is interrupted by an error
        writer.close()

    # NOTE(review): the tracked score is `recalls[1]`; if `validate` returns
    # a mapping keyed by N this is Recall@1, not Recall@5 -- confirm.
    print("=> Best Recall@5: {:.4f}".format(best_score), flush=True)
    save_path = join(save_file_path, "model_best.pth.tar")
    print("Done")

    if add_pca:
        print("Adding PCA layer")
        # PCA has to be computed for the FINE-TUNED network. Reloading
        # `checkpoint["state_dict"]` here would restore the weights from
        # before the training and discard its result.
        if best_saved and isfile(save_path):
            tuned_state_dict = torch.load(
                save_path, map_location=lambda storage, loc: storage
            )["state_dict"]
        else:
            # No best checkpoint was written in this run,
            # fall back to the weights currently held in memory.
            tuned_state_dict = model.state_dict()

        model = get_model(
            self.encoder,
            self.encoder_dim,
            self.num_clusters,
            self.use_vladv2,
            append_pca_layer=False,
        )
        model.load_state_dict(tuned_state_dict)
        model = model.to(self.device)

        pool_size = self.encoder_dim * self.num_clusters

        print("===> Loading PCA dataset(s)")

        # No more images can be sampled than the target database has
        n_features = min(n_features, len(target_db))

        sampler = SubsetRandomSampler(
            np.random.choice(len(target_db), n_features, replace=False)
        )

        # NOTE(review): `self.resize` is not passed here, unlike in
        # `get_database_descriptors` -- confirm this is intended.
        data_loader = DataLoader(
            dataset=IDataset(db_paths),
            num_workers=self.threads,
            batch_size=cache_bs,
            shuffle=False,
            pin_memory=self.cuda,
            sampler=sampler,
        )

        print("===> Do inference to extract features and save them.")

        model.eval()
        with torch.no_grad():
            tqdm.write("====> Extracting Features")

            db_feat = np.empty((len(data_loader.sampler), pool_size))
            print("Compute", len(db_feat), "features")

            for iteration, (input_data, _) in tqdm(
                enumerate(data_loader), total=len(data_loader)
            ):
                input_data = input_data.to(self.device)
                image_encoding = model.encoder(input_data)
                vlad_encoding = model.pool(image_encoding)
                out_vectors = vlad_encoding.detach().cpu().numpy()
                # this allows for randomly shuffled inputs: rows are filled
                # in the order of arrival, not by dataset index
                start = iteration * data_loader.batch_size
                db_feat[start : start + len(out_vectors), :] = out_vectors

                del input_data, image_encoding, vlad_encoding

        print("===> Compute PCA, takes a while")
        model_pca = pca(model, num_pcs, db_feat, pool_size)

        save_path = save_path.replace(".pth.tar", "_WPCA.pth.tar")

        torch.save({"state_dict": model_pca.state_dict()}, save_path)

        print("Done")

    return save_path
Fine-tunes the NetVLAD model for given target database
Parameters
- target_db: The database for which the model will be fine-tuned
- valid_db: Validation database
- train_db: Training database
- save_dir: Directory for saving output model and log
- voxel_size: Voxel size for down sampling point clouds
Returns
Path to output model
class
SuperGlue:
56class SuperGlue: 57 """ 58 Implementation of [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) 59 matcher with SuperPoint extractor. 60 """ 61 62 def __init__( 63 self, 64 path_to_sp_weights, 65 path_to_sg_weights, 66 resize=(640, 480), 67 resize_float=False, 68 ): 69 self.device = "cuda" if torch.cuda.is_available() else "cpu" 70 print('Running inference on device "{}"'.format(self.device)) 71 72 self.super_point = SuperPoint(path_to_sp_weights).eval().to(self.device) 73 self.super_glue_matcher = ( 74 SuperGlueMatcher(path_to_sg_weights).eval().to(self.device) 75 ) 76 self.resize = resize 77 self.resize_float = resize_float 78 79 def get_database_features(self, database: Database): 80 """ 81 Gets database RGB images SuperPoint features 82 :param database: Database for getting features 83 :return: Features for database images 84 """ 85 features = [] 86 for image in tqdm(database.color_images): 87 inp = read_image(image.path, self.device, self.resize, self.resize_float) 88 with torch.no_grad(): 89 features_for_query = self.super_point({"image": inp}) 90 features.append(features_for_query) 91 return np.asarray(features) 92 93 def match_feature(self, query_feature, db_features): 94 """ 95 Matches query feature with database features 96 :param query_feature: Feature for matching 97 :param db_features: Database features 98 :return: Index of matched image from database 99 """ 100 query_image_results = [] 101 for db_index, db_feature in enumerate(db_features): 102 pred = {k + "0": v for k, v in query_feature.items()} 103 pred = {**pred, **{k + "1": v for k, v in db_feature.items()}} 104 with torch.no_grad(): 105 pred = self.super_glue_matcher(pred, self.resize) 106 pred = {k: v[0].cpu().numpy() for k, v in pred.items()} 107 108 matches = pred["matches0"] 109 num_matches = np.sum(matches > -1) 110 query_image_results.append(num_matches) 111 return np.argmax(query_image_results)
Implementation of SuperGlue matcher with SuperPoint extractor.
SuperGlue( path_to_sp_weights, path_to_sg_weights, resize=(640, 480), resize_float=False)
def __init__(
    self,
    path_to_sp_weights,
    path_to_sg_weights,
    resize=(640, 480),
    resize_float=False,
):
    """
    Creates the SuperPoint extractor and the SuperGlue matcher
    :param path_to_sp_weights: Path to the SuperPoint weights
    :param path_to_sg_weights: Path to the SuperGlue weights
    :param resize: Size passed to `read_image` and to the matcher
    :param resize_float: Flag passed to `read_image`
    """
    if torch.cuda.is_available():
        self.device = "cuda"
    else:
        self.device = "cpu"
    print('Running inference on device "{}"'.format(self.device))

    # Both networks are switched to evaluation mode before being moved
    extractor = SuperPoint(path_to_sp_weights)
    self.super_point = extractor.eval().to(self.device)
    matcher = SuperGlueMatcher(path_to_sg_weights)
    self.super_glue_matcher = matcher.eval().to(self.device)

    self.resize = resize
    self.resize_float = resize_float
def get_database_features(self, database: Database):
    """
    Gets database RGB images SuperPoint features
    :param database: Database for getting features
    :return: Features for database images, one entry per color image
    """

    def extract(image):
        # Reads one image from disk and runs the extractor on it
        inp = read_image(image.path, self.device, self.resize, self.resize_float)
        with torch.no_grad():
            return self.super_point({"image": inp})

    features = [extract(image) for image in tqdm(database.color_images)]
    return np.asarray(features)
Gets database RGB images SuperPoint features
Parameters
- database: Database for getting features
Returns
Features for database images
def
match_feature(self, query_feature, db_features):
def match_feature(self, query_feature, db_features):
    """
    Matches query feature with database features
    :param query_feature: Feature for matching
    :param db_features: Database features
    :return: Index of matched image from database, i.e. of the entry with
        the most matches (the first one in case of a tie)
    :raises ValueError: If `db_features` is empty
    """
    # The query side of the matcher input is the same for every database
    # entry, so it is built once: its keys get the "0" suffix.
    query_part = {k + "0": v for k, v in query_feature.items()}

    query_image_results = []
    with torch.no_grad():
        for db_feature in db_features:
            # Keys of the database entry get the "1" suffix
            pred = {**query_part, **{k + "1": v for k, v in db_feature.items()}}
            pred = self.super_glue_matcher(pred, self.resize)
            pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

            # Entries greater than -1 are counted as matches
            matches = pred["matches0"]
            query_image_results.append(np.sum(matches > -1))

    if not query_image_results:
        raise ValueError("db_features is empty, there is nothing to match with")
    return np.argmax(query_image_results)
Matches query feature with database features
Parameters
- query_feature: Feature for matching
- db_features: Database features
Returns
Index of matched image from database