"""
The main widget for the interactive registration part of scikit-surgeryFRED
"""
from random import shuffle
import matplotlib.pyplot as plt
from sksurgeryfredmatplotlib.algorithms.ablation import Ablator
from sksurgeryfredmatplotlib.logging.fred_logger import Logger
from sksurgeryfredmatplotlib.widgets.fred_common import FredCommon
class RegistrationGame(FredCommon):
    """
    An interactive window for doing live registration, played as a game.

    Each round the player performs a registration and presses 'a' to
    ablate at the estimated target position; the ablation is scored and
    logged. The up/down arrow keys change the ablation margin. The first
    rounds are played in the 'Actual TRE' visibility state; later rounds
    use a randomly chosen visibility state from VisibilitySettings.
    """

    def __init__(self, image_file_name, headless=False):
        """
        Creates a visualisation of the projected and
        detected screen points, which you can click on
        to measure distances.

        :param image_file_name: image file, passed on to FredCommon
        :param headless: passed on to FredCommon
        """
        super().__init__(image_file_name, headless)
        self.stats_plot.set_visibilities(True, True, False, False, False,
                                         True, True, True, True)
        self.state_string = 'Actual TRE'
        self.repeats = 20
        # Rounds played with repeats 20..17 keep the 'Actual TRE' state;
        # each of the remaining 16 rounds draws a state from this buffer.
        self.visibility_setter = VisibilitySettings(self.repeats - 4)
        self.total_score = 0
        # Set once the final round has been scored, so that further 'a'
        # presses no longer ablate, score or log.
        self._finished = False
        self.stats_plot.update_last_score(0)
        self.stats_plot.update_total_score(self.total_score)
        self.plotter.show_actual_positions = False
        log_config = {"logger" : {
            "log file name" : "fred_game.log",
            "overwrite existing" : False
            }}
        self.logger = Logger(log_config)
        self.ablation = Ablator(margin=1.0)
        self.initialise_registration()
        self._release_ablate_key()
        _ = self.fig.canvas.mpl_connect('key_press_event',
                                        self.keypress_event)
        plt.show()

    @staticmethod
    def _release_ablate_key():
        """
        Stop matplotlib's 'keymap.all_axes' binding from also handling 'a'.

        This rcParam was removed in newer matplotlib releases, and 'a' may
        already have been removed by an earlier game in the same session;
        both cases are tolerated instead of raising KeyError/ValueError.
        """
        all_axes_keys = plt.rcParams.get('keymap.all_axes', [])
        if 'a' in all_axes_keys:
            all_axes_keys.remove('a')

    def keypress_event(self, event):
        """
        Handle a key press event.

        'up' and 'down' increase and decrease the ablation margin; 'a'
        ablates at the estimated target and, if that produced a score,
        advances to the next round or ends the game.

        :param event: the matplotlib key press event
        """
        if event.key == "up":
            self._show_margin(self.ablation.increase_margin())
        elif event.key == "down":
            self._show_margin(self.ablation.decrease_margin())
        elif event.key == "a":
            if not self._finished:
                self._ablate_and_score()
            self.fig.canvas.draw()

    def _show_margin(self, margin):
        """Display a new ablation margin and redraw the figure."""
        self.stats_plot.update_margin_stats(margin)
        self.fig.canvas.draw()

    def _ablate_and_score(self):
        """
        Ablate at the estimated target and, if a score results, record it
        and move on to the next round (or end the game after the last).
        """
        reg_ok, est_target = self.pbr.get_transformed_target()
        if not reg_ok:
            return
        score = self.ablation.ablate(est_target)
        if score is None:
            return
        self.stats_plot.update_last_score(score)
        self.total_score += score
        self.stats_plot.update_total_score(self.total_score)
        # Logged under the state that was visible while this round was played.
        self.logger.log_score(self.state_string, score)
        if self.repeats > 1:
            self._next_round()
        else:
            self._finished = True
            self._game_over()

    def _next_round(self):
        """Pick the visibility state for the next round and start it."""
        # The threshold of 18 pairs with the buffer of (repeats - 4) states:
        # states are drawn after the rounds with repeats 17..2, i.e. exactly
        # 16 times, which empties the buffer on the last round.
        if self.repeats < 18:
            *visibilities, self.state_string = \
                self.visibility_setter.get_vis_state()
            self.stats_plot.set_visibilities(*visibilities)
        self.repeats -= 1
        self.stats_plot.update_repeats(self.repeats)
        self.initialise_registration()

    def _game_over(self):
        """Overlay the 'Game Over' message and instructions on the figure."""
        props = dict(boxstyle='round', facecolor='wheat', alpha=1.0)
        self.fig.text(0.2, 0.7, "Game Over",
                      fontsize=56, bbox=props)
        text_str = ("Thanks for playing.\n" +
                    "Please let me know your scores by sending the log file\n" +
                    "'fred_game.log' and any comments to s.thompson@ucl.ac.uk")
        self.fig.text(0.2, 0.4, text_str,
                      fontsize=26, bbox=props)
        self.fig.canvas.draw()

    def initialise_registration(self):
        """
        Sets up the registration for a new round: initialises it via
        FredCommon.init_reg, points the ablator at the returned target and
        refreshes the margin and repeats statistics.
        """
        target_point = super().init_reg()
        self.ablation.setup(target=target_point,
                            target_radius=10.0)
        self.stats_plot.update_margin_stats(self.ablation.margin)
        self.stats_plot.update_repeats(self.repeats)
        self.fig.canvas.draw()
class VisibilitySettings:
    """
    Randomly selects from a list of visibility states. There are four
    states, each named by the last element of its list:

    * 'FLE and Number of Fids'
    * 'Expected TRE'
    * 'Expected FRE'
    * 'Actual FRE'

    Each state is a list of nine boolean visibility flags followed by the
    state's name.
    """

    def __init__(self, buffer_size):
        """
        :param buffer_size: the total number of states to hold, i.e. how
            many times get_vis_state can be called. Must be a non-negative
            multiple of 4 so that each state appears equally often.
        :raises ValueError: if buffer_size is not divisible by 4 or is
            negative
        """
        if buffer_size % 4 != 0:
            raise ValueError("Buffer size must be divisible by 4")
        if buffer_size < 0:
            raise ValueError("Buffer size must be non-negative")
        each_bin = int(buffer_size // 4)
        fle_and_fids = [True, False, False, False, False,
                        True, True, True, True, 'FLE and Number of Fids']
        exp_tre = [False, False, True, False, False, True, True, True, True,
                   'Expected TRE']
        exp_fre = [False, False, False, True, False, True, True, True, True,
                   'Expected FRE']
        actual_fre = [False, False, False, False, True, True, True, True, True,
                      'Actual FRE']
        templates = (fle_and_fids, exp_tre, exp_fre, actual_fre)
        # Store independent copies, so a caller that mutates one returned
        # state cannot alter the states still waiting in the buffer.
        self.state_list = [list(state)
                           for _ in range(each_bin)
                           for state in templates]

    def get_vis_state(self):
        """
        Returns a random visibility state, removing it from the buffer.

        :returns: a list of nine visibility flags followed by the state name
        :raises IndexError: if every state has already been returned
        """
        shuffle(self.state_list)
        try:
            return self.state_list.pop()
        except IndexError as err:
            raise IndexError("You tried to get a value from "
                             "VisibilitySettings, but "
                             "the buffer is emptied.") from err