Properly splitting a YOLOv3 model output filename into 3 variables to get the best model

Issue

I have a lot of files that are the output of training a YOLO model with TensorFlow.

Each file is named like this:

detection_model-ex-013--loss-0016.228.h5

The only parts that differ between files are:

013 – epoch (generation number)

0016 – loss (roughly the training-side counterpart of accuracy; lower is better)

228 – additional precision of the loss (I think, but I'm not sure yet, so I treat it as a separate value for now)
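The zero padding suggests the name is produced by a Python format template, such as the filepath templates accepted by Keras's ModelCheckpoint callback. As a hedged check, assuming a template like `detection_model-ex-{epoch:03d}--loss-{loss:08.3f}.h5` (an assumption; the actual training code isn't shown here), formatting epoch 13 and loss 16.228 reproduces the filename exactly, which would make 228 the fractional part of the loss rather than a separate value:

```python
# Hypothetical template: assumed, not taken from the actual training code.
TEMPLATE = 'detection_model-ex-{epoch:03d}--loss-{loss:08.3f}.h5'

name = TEMPLATE.format(epoch=13, loss=16.228)
print(name)  # detection_model-ex-013--loss-0016.228.h5

# Under this assumption, '0016.228' is just the float 16.228,
# zero-padded to 8 characters with 3 decimal places.
print(float('0016.228'))  # 16.228
```

Note that `{epoch:03d}` only sets a minimum width, which is consistent with longer epoch numbers such as 83713 appearing unpadded.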

It seems simple, but the name mixes dashes, underscores and dots. Since I'm fairly new to Python and what it can do, I couldn't find a solution that fits this filename, and I'm having trouble writing a regex for it.
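For a name this regular, a regex isn't strictly needed. Here is a minimal sketch using only string methods, assuming every file follows exactly the `<prefix>-ex-<epoch>--loss-<loss>.<precision>.h5` shape; `split_model_name` is a hypothetical helper, not part of any library:

```python
import os

def split_model_name(path):
    """Return (epoch, loss, precision) as strings from a model filename."""
    name = os.path.basename(path)        # drop the directory part
    stem = os.path.splitext(name)[0]     # 'detection_model-ex-013--loss-0016.228'
    left, right = stem.split('--')       # 'detection_model-ex-013', 'loss-0016.228'
    epoch = left.rsplit('-', 1)[1]       # text after the last dash: '013'
    loss_text = right.split('-', 1)[1]   # '0016.228'
    loss, precision = loss_text.split('.')  # '0016', '228'
    return epoch, loss, precision

print(split_model_name('data/models/detection_model-ex-013--loss-0016.228.h5'))
# ('013', '0016', '228')
```

Because of the tuple unpacking and indexing, a name with a different shape fails loudly with an exception instead of silently returning wrong values.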

For now, my Python "logic" selects the model solely based on which file was saved last. It's a start, but it's far from ideal.

import glob
import os

def findBestModel():
    # "best" model is currently just the most recently modified .h5 file
    last_model = max(glob.glob('data/models/*.h5'), key=os.path.getmtime)
    print('Selected model: ' + last_model)
    return last_model

What is the right way to make findBestModel() extract those 3 values from each filename, and pick the best model based on them, without overcomplicating it?

Solution

You could use re.match on each of the paths you already get from glob.glob. With named groups, each match object gives you the three values from the file name (epoch, loss and precision): m.group('loss') returns one of them, and m.groupdict() returns all three as a dictionary. The matches can then be sorted with the key parameter of sorted, using the loss value captured by the regex (int(x.group('loss'))). The match object's string attribute holds the string that was passed to match(), so for the first (lowest-loss) match it is the path to your best model.
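To see what the named groups return, here is the regex from the code below applied to a single path. Everything used here (re.compile, Match.group, Match.groupdict, Match.string) is part of Python's standard re module:

```python
import re

# Same named-group pattern as in the answer's code
pattern = re.compile(
    r'.*?-(?P<epoch>\d+)--'
    r'loss-(?P<loss>\d+)\.'
    r'(?P<precision>\d+)\.h5'
)

m = pattern.match('./data/models/detection_model-ex-013--loss-0016.228.h5')
print(m.group('loss'))       # 0016
print(m.groupdict())         # {'epoch': '013', 'loss': '0016', 'precision': '228'}
print(m.string)              # ./data/models/detection_model-ex-013--loss-0016.228.h5
print(int(m.group('loss')))  # 16 (int() drops the leading zeros)
```

The lazy `.*?-` skips ahead to the first dash that is directly followed by digits and `--`, which is the one before the epoch number.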

data
└── models
    ├── detection_model-ex-002--loss-3416.127.h5
    ├── detection_model-ex-013--loss-0016.228.h5
    ├── detection_model-ex-2462--loss-0173.093.h5
    ├── detection_model-ex-486--loss-1933.981.h5
    └── detection_model-ex-83713--loss-0001.048.h5
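
import glob
import re

def findBestModel():
    models = []
    for model in glob.glob('./data/models/*.h5'):
        m = re.match(
            r'.*?-(?P<epoch>\d+)--'
            r'loss-(?P<loss>\d+)\.'
            r'(?P<precision>\d+)\.h5',
            model)
        if m is not None:  # skip .h5 files that don't follow the naming scheme
            models.append(m)

    if not models:
        return None  # no matching model files found

    # Sort by the integer part of the loss, lowest first
    sorted_models = sorted(models, key=lambda x: int(x.group('loss')))
    # .string is the path that was originally passed to re.match()
    return sorted_models[0].string

best_model = findBestModel()
print(best_model)  # './data/models/detection_model-ex-83713--loss-0001.048.h5'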
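The question also asked for the three values themselves, and sorting on int(x.group('loss')) ignores the digits after the dot, so models with loss 0016.228 and 0016.100 would tie. Below is a sketch of one alternative, assuming, without confirmation, that those digits are the fractional part of the loss. It takes a list of paths instead of globbing itself (which makes it easy to test), uses min() instead of a full sort, and returns the path together with the parsed epoch and loss. `find_best_model` is a hypothetical name:

```python
import re

# Assumes '<loss>.<precision>' together form one decimal number
MODEL_PATTERN = re.compile(
    r'.*?-(?P<epoch>\d+)--loss-(?P<loss>\d+\.\d+)\.h5$'
)

def find_best_model(paths):
    """Return (path, epoch, loss) for the lowest-loss model, or None."""
    candidates = []
    for path in paths:
        m = MODEL_PATTERN.match(path)
        if m is None:
            continue  # ignore files that don't follow the naming scheme
        candidates.append((path, int(m.group('epoch')), float(m.group('loss'))))
    if not candidates:
        return None
    # Lowest loss wins; on an exact tie, the earlier epoch is preferred
    return min(candidates, key=lambda c: (c[2], c[1]))

paths = [
    'data/models/detection_model-ex-002--loss-3416.127.h5',
    'data/models/detection_model-ex-013--loss-0016.228.h5',
    'data/models/detection_model-ex-2462--loss-0173.093.h5',
    'data/models/detection_model-ex-486--loss-1933.981.h5',
    'data/models/detection_model-ex-83713--loss-0001.048.h5',
]
print(find_best_model(paths))
# ('data/models/detection_model-ex-83713--loss-0001.048.h5', 83713, 1.048)
```

In practice you would call it as find_best_model(glob.glob('data/models/*.h5')) and unpack the result into path, epoch and loss.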

Answered By – n1colas.m

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
