vprdb.vpr_systems

The `vpr_systems` module contains a set of tools for the Visual Place Recognition (VPR) task.

#  Copyright (c) 2023, Ivan Moskalenko, Anastasiia Kornilova
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""The `vpr_systems` module contains a set of tools for the VPR task."""
from vprdb.vpr_systems.cos_place import CosPlace
from vprdb.vpr_systems.netvlad import NetVLAD
from vprdb.vpr_systems.superglue import SuperGlue

__all__ = ["CosPlace", "NetVLAD", "SuperGlue"]
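
All three systems follow the same pattern: build global descriptors (CosPlace, NetVLAD) or local features (SuperGlue) for a reference Database, then match query images against them. A minimal import sketch, using the names exported in __all__ above:

from vprdb.vpr_systems import CosPlace, NetVLAD, SuperGlue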
class CosPlace:
class CosPlace:
    """
    Implementation of [CosPlace](https://github.com/gmberton/CosPlace) global localization method.
    """

    def __init__(self, backbone: str, fc_output_dim: int, path_to_weights: str):
        self.backbone = backbone
        self.fc_output_dim = fc_output_dim
        self.path_to_weights = path_to_weights

        self.model = GeoLocalizationNet(backbone, fc_output_dim)
        model_state_dict = torch.load(self.path_to_weights)
        self.model.load_state_dict(model_state_dict)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def get_database_descriptors(self, database: Database):
        """
        Gets CosPlace descriptors for the database RGB images
        :param database: Database for getting descriptors
        :return: Descriptors for database images
        """
        self.model.eval()
        image_providers = database.color_images
        # Standard ImageNet normalization; the transform is constant,
        # so it is built once outside the loop
        base_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        with torch.no_grad():
            all_descriptors = np.empty(
                (len(image_providers), self.fc_output_dim), dtype="float32"
            )
            for i, image in tqdm(
                enumerate(image_providers), total=len(image_providers)
            ):
                image_bgr = image.color_image
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                normalized_img = base_transform(image_rgb)
                normalized_img = normalized_img[None, :]
                descriptor = self.model(normalized_img.to(self.device))
                descriptor = descriptor.cpu().numpy()
                all_descriptors[i] = descriptor
        return all_descriptors

    def fine_tune_model(
        self,
        target_db: Database,
        valid_db: Database,
        train_db: Database,
        save_dir: str,
        voxel_size=0.3,
        random_resize=(480, 640),
        brightness=0.7,
        contrast=0.7,
        saturation=0.7,
        hue=0.5,
        random_resized_crop=0.5,
        num_workers=0,
        batch_size=8,
        lr=0.00001,
        classifiers_lr=0.01,
        patience=10,
        max_epochs=-1,
        seed=0,
    ) -> str:
        """
        Fine-tunes the CosPlace model for a given target database
        :param target_db: The database for which the model will be fine-tuned
        :param valid_db: Validation database
        :param train_db: Training database
        :param save_dir: Directory for saving output model and log
        :param voxel_size: Voxel size for downsampling point clouds
        :return: Path to output model
        """
        min_bounds, max_bounds = find_bounds_for_multiple_databases(
            [target_db, valid_db, train_db]
        )
        voxel_grid = VoxelGrid(min_bounds, max_bounds, voxel_size)

        groups = create_groups(train_db, target_db, voxel_grid)
        groups_lens = [len(group_db) for group_db, _ in groups]
        val_matches = match_two_databases(valid_db, target_db, voxel_grid)

        make_deterministic(seed=seed)

        data = DataModule(
            groups,
            target_db,
            valid_db,
            val_matches,
            random_resize,
            brightness,
            contrast,
            saturation,
            hue,
            random_resized_crop,
            num_workers,
            batch_size,
        )
        train_model = CosPlaceTrainer(
            self.model,
            self.fc_output_dim,
            groups_lens,
            len(target_db),
            len(valid_db),
            lr,
            classifiers_lr,
            self.device,
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=save_dir, filename="best_model", save_weights_only=True
        )
        # Start training
        trainer = Trainer(
            accelerator="auto",
            devices=[0],
            max_epochs=max_epochs,
            callbacks=[
                EarlyStopping(monitor="R_1", mode="max", patience=patience),
                checkpoint_callback,
            ],
            default_root_dir=save_dir,
        )

        trainer.fit(
            train_model,
            datamodule=data,
        )

        # Transform weights to PyTorch format: the Lightning wrapper prefixes
        # every key with "model.", so strip the first six characters
        model_state_dict = torch.load(checkpoint_callback.best_model_path)
        model_state_dict = model_state_dict["state_dict"]
        new_model_state_dict = dict()
        for k in model_state_dict.keys():
            new_model_state_dict[k[6:]] = model_state_dict[k]
        torch.save(new_model_state_dict, checkpoint_callback.best_model_path)
        return checkpoint_callback.best_model_path

Implementation of the CosPlace global localization method.

CosPlace(backbone: str, fc_output_dim: int, path_to_weights: str)

Attributes
  • backbone
  • fc_output_dim
  • path_to_weights
  • model
  • device
def get_database_descriptors(self, database: vprdb.core.database.Database):

Gets CosPlace descriptors for the database RGB images

Parameters
  • database: Database for getting descriptors
Returns

Descriptors for database images
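
A minimal retrieval sketch, assuming db and query_db are vprdb Database instances; the backbone name, descriptor size and weights path are placeholders that must match your checkpoint:

import numpy as np
from vprdb.vpr_systems import CosPlace

cos_place = CosPlace(
    backbone="ResNet18", fc_output_dim=512,
    path_to_weights="weights/cosplace_resnet18_512.pth",  # placeholder path
)
db_descriptors = cos_place.get_database_descriptors(db)          # (len(db), 512)
query_descriptors = cos_place.get_database_descriptors(query_db)

# Brute-force nearest-neighbour search over the descriptors
distances = np.linalg.norm(
    query_descriptors[:, None, :] - db_descriptors[None, :, :], axis=-1
)
best_matches = distances.argmin(axis=1)  # closest database image per query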

def fine_tune_model( self, target_db: vprdb.core.database.Database, valid_db: vprdb.core.database.Database, train_db: vprdb.core.database.Database, save_dir: str, voxel_size=0.3, random_resize=(480, 640), brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5, random_resized_crop=0.5, num_workers=0, batch_size=8, lr=1e-05, classifiers_lr=0.01, patience=10, max_epochs=-1, seed=0) -> str:

Fine-tunes the CosPlace model for a given target database

Parameters
  • target_db: The database for which the model will be fine-tuned
  • valid_db: Validation database
  • train_db: Training database
  • save_dir: Directory for saving output model and log
  • voxel_size: Voxel size for downsampling point clouds
Returns

Path to output model
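
A hedged fine-tuning sketch; target_db, valid_db and train_db are assumed to be vprdb Database instances, and all paths and the backbone settings are placeholders:

cos_place = CosPlace(
    backbone="ResNet18", fc_output_dim=512,
    path_to_weights="weights/cosplace_resnet18_512.pth",
)
best_model_path = cos_place.fine_tune_model(
    target_db, valid_db, train_db,
    save_dir="runs/cosplace_finetune",
    voxel_size=0.3,  # per the docstring: voxel size for downsampling point clouds
)
# The returned checkpoint is already converted to plain PyTorch format,
# so it can be passed straight back to the constructor
cos_place = CosPlace("ResNet18", 512, best_model_path)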

class NetVLAD:
class NetVLAD:
    """
    Implementation of [NetVLAD](https://github.com/QVPR/Patch-NetVLAD) global localization method.
    """

    def __init__(
        self,
        path_to_weights: str,
        resize: tuple[int, int] = (480, 640),
        threads: int = 0,
        batch_size: int = 20,
        use_vladv2: bool = False,
    ):
        self.cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.cuda else "cpu")
        self.encoder_dim, self.encoder = get_backend()

        self.resize = resize
        self.threads = threads
        self.batch_size = batch_size
        self.use_vladv2 = use_vladv2

        if isfile(path_to_weights):
            self.path_to_weights = path_to_weights
        else:
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(path_to_weights)
            )

        self.checkpoint = torch.load(
            self.path_to_weights, map_location=lambda storage, loc: storage
        )
        self.num_clusters = self.checkpoint["state_dict"]["pool.centroids"].shape[0]

    def get_database_descriptors(
        self,
        database: Database,
    ):
        """
        Gets NetVLAD descriptors for the database RGB images
        :param database: Database for getting descriptors
        :return: Descriptors for database images
        """
        num_pcs = self.checkpoint["state_dict"]["WPCA.0.bias"].shape[0]

        model = get_model(
            self.encoder,
            self.encoder_dim,
            self.num_clusters,
            self.use_vladv2,
            append_pca_layer=True,
            num_pcs=num_pcs,
        )
        model.load_state_dict(self.checkpoint["state_dict"])
        model = model.to(self.device)

        color_images_paths = [img.path for img in database.color_images]
        dataset = IDataset(color_images_paths, self.resize)
        test_data_loader = DataLoader(
            dataset=dataset,
            num_workers=self.threads,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=self.cuda,
        )
        model.eval()
        with torch.no_grad():
            db_feat = np.empty((len(dataset), num_pcs), dtype=np.float32)
            for iteration, (input_data, indices) in tqdm(
                enumerate(test_data_loader), total=len(test_data_loader)
            ):
                indices_np = indices.detach().numpy()
                input_data = input_data.to(self.device)
                image_encoding = model.encoder(input_data)
                vlad_global = model.pool(image_encoding)
                vlad_global_pca = get_pca_encoding(model, vlad_global)
                db_feat[indices_np, :] = vlad_global_pca.detach().cpu().numpy()

        return db_feat

    def fine_tune_model(
        self,
        target_db: Database,
        valid_db: Database,
        train_db: Database,
        save_dir: str,
        voxel_size: float = 0.3,
        seed=42,
        add_pca=True,
        optim_name="SGD",
        lr=0.0001,
        momentum=0.9,
        weight_decay=0.001,
        lr_step=5,
        lr_gamma=0.5,
        margin=0.1,
        nNeg=5,
        cache_bs=20,
        bs=4,
        max_epochs=100,
        eval_every=1,
        patience=5,
        n_features=10000,
        num_pcs=4096,
    ) -> str:
        """
        Fine-tunes the NetVLAD model for a given target database
        :param target_db: The database for which the model will be fine-tuned
        :param valid_db: Validation database
        :param train_db: Training database
        :param save_dir: Directory for saving output model and log
        :param voxel_size: Voxel size for downsampling point clouds
        :return: Path to output model
        """
        min_bounds, max_bounds = find_bounds_for_multiple_databases(
            [target_db, valid_db, train_db]
        )

        voxel_grid = VoxelGrid(min_bounds, max_bounds, voxel_size)
        train_targets = match_two_databases(train_db, target_db, voxel_grid)
        val_targets = match_two_databases(valid_db, target_db, voxel_grid)

        train_paths = [img.path for img in train_db.color_images]
        val_paths = [img.path for img in valid_db.color_images]
        db_paths = [img.path for img in target_db.color_images]

        make_deterministic(seed=seed)

        scheduler = None

        checkpoint = torch.load(
            self.path_to_weights, map_location=lambda storage, loc: storage
        )
        # Deleting WPCA layer; it is recomputed after fine-tuning
        del checkpoint["state_dict"]["WPCA.0.weight"]
        del checkpoint["state_dict"]["WPCA.0.bias"]

        model = get_model(
            self.encoder, self.encoder_dim, self.num_clusters, self.use_vladv2
        )
        model.load_state_dict(checkpoint["state_dict"])

        if optim_name == "ADAM":
            optimizer = optim.Adam(
                filter(lambda par: par.requires_grad, model.parameters()), lr=lr
            )
        elif optim_name == "SGD":
            optimizer = optim.SGD(
                filter(lambda par: par.requires_grad, model.parameters()),
                lr=lr,
                momentum=momentum,
                weight_decay=weight_decay,
            )

            scheduler = optim.lr_scheduler.StepLR(
                optimizer, step_size=lr_step, gamma=lr_gamma
            )
        else:
            raise ValueError("Unknown optimizer: " + optim_name)

        criterion = nn.TripletMarginLoss(
            margin=(margin**0.5), p=2, reduction="sum"
        ).to(self.device)

        model = model.to(self.device)

        print("===> Loading dataset(s)")
        train_dataset = TDataset(
            db_paths,
            train_paths,
            train_targets,
            n_neg=nNeg,
            transform=input_transform(),
            bs=cache_bs,
            threads=self.threads,
        )

        validation_dataset = TDataset(
            db_paths,
            val_paths,
            val_targets,
            n_neg=nNeg,
            transform=input_transform(),
            bs=cache_bs,
            threads=self.threads,
        )

        print("===> Training query set:", len(train_dataset.q_idx))
        print("===> Evaluating on val set, query count:", len(validation_dataset.q_idx))
        print("===> Training model")

        writer = SummaryWriter(log_dir=save_dir)

        logdir = writer.file_writer.get_logdir()
        save_file_path = join(logdir, "checkpoints")
        makedirs(save_file_path)

        not_improved = 0
        best_score = 0
        for epoch in trange(
            1, max_epochs + 1, desc="Epoch number".rjust(15), position=0
        ):
            train_epoch(
                train_dataset,
                model,
                optimizer,
                criterion,
                self.encoder_dim,
                self.device,
                epoch,
                writer,
                bs,
                self.num_clusters,
                self.threads,
            )
            if scheduler is not None:
                scheduler.step(epoch)
            if (epoch % eval_every) == 0:
                recalls = validate(
                    validation_dataset,
                    model,
                    self.encoder_dim,
                    self.device,
                    writer,
                    self.threads,
                    cache_bs,
                    self.num_clusters,
                    epoch,
                    write_tboard=True,
                    pbar_position=1,
                )
                # Early stopping is driven by Recall@1 on the validation set
                is_best = recalls[1] > best_score
                if is_best:
                    not_improved = 0
                    best_score = recalls[1]
                else:
                    not_improved += 1

                save_checkpoint(
                    {
                        "epoch": epoch,
                        "state_dict": model.state_dict(),
                        "recalls": recalls,
                        "best_score": best_score,
                        "not_improved": not_improved,
                        "optimizer": optimizer.state_dict(),
                        "parallel": False,
                    },
                    is_best,
                    save_file_path,
                )

                if patience > 0 and not_improved > (patience / int(eval_every)):
                    print(
                        "Performance did not improve for", patience, "epochs. Stopping."
                    )
                    break

        print("=> Best Recall@1: {:.4f}".format(best_score), flush=True)
        writer.close()
        save_path = join(save_file_path, "model_best.pth.tar")
        print("Done")

        if add_pca:
            print("Adding PCA layer")
            model = get_model(
                self.encoder,
                self.encoder_dim,
                self.num_clusters,
                self.use_vladv2,
                append_pca_layer=False,
            )
            # Load the best fine-tuned weights saved during training,
            # not the original pre-training checkpoint
            best_checkpoint = torch.load(
                save_path, map_location=lambda storage, loc: storage
            )
            model.load_state_dict(best_checkpoint["state_dict"])
            model = model.to(self.device)

            pool_size = self.encoder_dim
            pool_size *= self.num_clusters

            print("===> Loading PCA dataset(s)")

            if n_features > len(target_db):
                n_features = len(target_db)

            sampler = SubsetRandomSampler(
                np.random.choice(len(target_db), n_features, replace=False)
            )

            data_loader = DataLoader(
                dataset=IDataset(db_paths),
                num_workers=self.threads,
                batch_size=cache_bs,
                shuffle=False,
                pin_memory=self.cuda,
                sampler=sampler,
            )

            print("===> Do inference to extract features and save them.")

            model.eval()
            with torch.no_grad():
                tqdm.write("====> Extracting Features")

                db_feat = np.empty((len(data_loader.sampler), pool_size))
                print("Compute", len(db_feat), "features")

                for iteration, (input_data, indices) in tqdm(
                    enumerate(data_loader), total=len(data_loader)
                ):
                    input_data = input_data.to(self.device)
                    image_encoding = model.encoder(input_data)
                    vlad_encoding = model.pool(image_encoding)
                    out_vectors = vlad_encoding.detach().cpu().numpy()
                    # This allows for randomly shuffled inputs
                    for idx, out_vector in enumerate(out_vectors):
                        db_feat[
                            iteration * data_loader.batch_size + idx, :
                        ] = out_vector

                    del input_data, image_encoding, vlad_encoding

            print("===> Compute PCA, takes a while")
            model_pca = pca(model, num_pcs, db_feat, pool_size)

            save_path = save_path.replace(".pth.tar", "_WPCA.pth.tar")

            torch.save({"state_dict": model_pca.state_dict()}, save_path)

            print("Done")

        return save_path

Implementation of the NetVLAD global localization method.

NetVLAD( path_to_weights: str, resize: tuple[int, int] = (480, 640), threads: int = 0, batch_size: int = 20, use_vladv2: bool = False)

Attributes
  • cuda
  • device
  • resize
  • threads
  • batch_size
  • use_vladv2
  • checkpoint
  • num_clusters
def get_database_descriptors(self, database: vprdb.core.database.Database):

Gets NetVLAD descriptors for the database RGB images

Parameters
  • database: Database for getting descriptors
Returns

Descriptors for database images
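
A minimal sketch, assuming db is a vprdb Database and the checkpoint includes a WPCA layer (num_pcs is read from WPCA.0.bias); the path is a placeholder:

from vprdb.vpr_systems import NetVLAD

net_vlad = NetVLAD(path_to_weights="weights/netvlad_wpca.pth.tar")  # placeholder
db_descriptors = net_vlad.get_database_descriptors(db)  # (len(db), num_pcs) float32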

def fine_tune_model( self, target_db: vprdb.core.database.Database, valid_db: vprdb.core.database.Database, train_db: vprdb.core.database.Database, save_dir: str, voxel_size: float = 0.3, seed=42, add_pca=True, optim_name='SGD', lr=0.0001, momentum=0.9, weight_decay=0.001, lr_step=5, lr_gamma=0.5, margin=0.1, nNeg=5, cache_bs=20, bs=4, max_epochs=100, eval_every=1, patience=5, n_features=10000, num_pcs=4096) -> str:

Fine-tunes the NetVLAD model for a given target database

Parameters
  • target_db: The database for which the model will be fine-tuned
  • valid_db: Validation database
  • train_db: Training database
  • save_dir: Directory for saving output model and log
  • voxel_size: Voxel size for downsampling point clouds
Returns

Path to output model
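
A hedged fine-tuning sketch; target_db, valid_db and train_db are assumed to be vprdb Database instances, and the paths are placeholders:

net_vlad = NetVLAD(path_to_weights="weights/netvlad_wpca.pth.tar")
best_path = net_vlad.fine_tune_model(
    target_db, valid_db, train_db,
    save_dir="runs/netvlad_finetune",
    max_epochs=30,
    patience=5,
)
# With add_pca=True (the default) the returned checkpoint already contains
# a recomputed WPCA layer, so it can be loaded directly for retrieval
net_vlad = NetVLAD(path_to_weights=best_path)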

class SuperGlue:
class SuperGlue:
    """
    Implementation of [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork)
    matcher with SuperPoint extractor.
    """

    def __init__(
        self,
        path_to_sp_weights,
        path_to_sg_weights,
        resize=(640, 480),
        resize_float=False,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print('Running inference on device "{}"'.format(self.device))

        self.super_point = SuperPoint(path_to_sp_weights).eval().to(self.device)
        self.super_glue_matcher = (
            SuperGlueMatcher(path_to_sg_weights).eval().to(self.device)
        )
        self.resize = resize
        self.resize_float = resize_float

    def get_database_features(self, database: Database):
        """
        Gets SuperPoint features for the database RGB images
        :param database: Database for getting features
        :return: Features for database images
        """
        features = []
        for image in tqdm(database.color_images):
            inp = read_image(image.path, self.device, self.resize, self.resize_float)
            with torch.no_grad():
                features_for_query = self.super_point({"image": inp})
            features.append(features_for_query)
        return np.asarray(features)

    def match_feature(self, query_feature, db_features):
        """
        Matches a query feature against the database features
        :param query_feature: Feature for matching
        :param db_features: Database features
        :return: Index of the best-matching database image
        """
        query_image_results = []
        for db_index, db_feature in enumerate(db_features):
            # The "0"/"1" suffixes mark query and database features respectively
            pred = {k + "0": v for k, v in query_feature.items()}
            pred = {**pred, **{k + "1": v for k, v in db_feature.items()}}
            with torch.no_grad():
                pred = self.super_glue_matcher(pred, self.resize)
            pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

            matches = pred["matches0"]
            num_matches = np.sum(matches > -1)
            query_image_results.append(num_matches)
        # The database image with the most SuperGlue matches wins
        return np.argmax(query_image_results)

Implementation of the SuperGlue matcher with a SuperPoint extractor.

SuperGlue( path_to_sp_weights, path_to_sg_weights, resize=(640, 480), resize_float=False)

Attributes
  • device
  • super_point
  • super_glue_matcher
  • resize
  • resize_float
def get_database_features(self, database: vprdb.core.database.Database):

Gets SuperPoint features for the database RGB images

Parameters
  • database: Database for getting features
Returns

Features for database images

def match_feature(self, query_feature, db_features):

Matches a query feature against the database features

Parameters
  • query_feature: Feature for matching
  • db_features: Database features
Returns

Index of the best-matching database image
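
A minimal matching sketch, assuming db and query_db are vprdb Database instances and the weight paths are placeholders:

from vprdb.vpr_systems import SuperGlue

super_glue = SuperGlue(
    path_to_sp_weights="weights/superpoint.pth",  # placeholder paths
    path_to_sg_weights="weights/superglue.pth",
)

db_features = super_glue.get_database_features(db)
query_features = super_glue.get_database_features(query_db)

# Find the database image with the most SuperGlue matches for the first query
best_index = super_glue.match_feature(query_features[0], db_features)

Because match_feature runs one SuperGlue forward pass per database image, it is typically used to re-rank a shortlist of candidates produced by a global method such as CosPlace or NetVLAD rather than being applied to the whole database.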