Fusion Model

Fusion Model

Data Example

from models import ModalityPredictorPCA, MODELTYPE
from generate_random_input import generate_fusion_image_input


df = generate_fusion_image_input()

# Label the first half of the rows for training/validation and the rest for
# testing. The TEST count is computed as the remainder so the label list always
# has exactly one entry per row: the previous int(n/2) + int(n/2) came up one
# short for an odd row count, and pandas rejects a column of mismatched length.
n_rows = df.shape[0]
n_train_validate = n_rows // 2
df["subset"] = ["TRAIN_VALIDATE"] * n_train_validate + ["TEST"] * (n_rows - n_train_validate)

df
data_image_1 data_image_2 age gender subset
0 [47.90268504737213, 55.605919219345566, 17.935... [0.47169120780684504, 0.7675746276459315, 0.92... 64 M TRAIN_VALIDATE
1 [67.36314460702869, 23.11991970043677, 12.7263... [0.25488659747587705, 0.8747048208394343, 0.10... 69 M TRAIN_VALIDATE
2 [27.779335170284362, 8.69078805515224, 16.5922... [0.8004833246345652, 0.5331968388732692, 0.438... 28 F TRAIN_VALIDATE
3 [42.58482520906333, 38.214889508728895, 48.448... [0.8525791325173845, 0.7134357681855932, 0.422... 89 M TRAIN_VALIDATE
4 [35.93036363590782, 31.917923604190253, 29.876... [0.5758026759371442, 0.27679724046350107, 0.40... 43 F TRAIN_VALIDATE
... ... ... ... ... ...
495 [44.86351359507677, 6.09766310076842, 3.646045... [0.14416804663111982, 0.7082797149926434, 0.68... 80 F TEST
496 [12.366117329906821, 22.60410325294449, 17.973... [0.2228989522235434, 0.6819998779459963, 0.076... 44 M TEST
497 [54.178132638627815, 22.0459855284983, 36.1001... [0.15068600550198197, 0.6247157139396373, 0.03... 69 M TEST
498 [33.35691101796781, 6.154465738131789, 16.9540... [0.006046339867387784, 0.11503217773837293, 0.... 42 F TEST
499 [25.88994784520547, 32.937509813732596, 14.926... [0.18396047157757056, 0.8423439769892563, 0.87... 46 F TEST

500 rows × 5 columns

Model

from sklearn import set_config
# Ask scikit-learn to render estimators as diagrams when displayed in the notebook.
set_config(display="diagram")

# Image modalities to fuse; each one gets its own PCA reduction
# (see PCA(n_components=...) in the rendered pipeline).
modality_columns = ["data_image_1", "data_image_2"]
number_components = 2

# NOTE(review): the meaning of the trailing positional 5 is not visible in this
# notebook -- confirm against ModalityPredictorPCA's signature.
predictor = ModalityPredictorPCA(df, modality_columns, MODELTYPE.FUSION, 5)

# Build the model-level fusion estimator and show it as the cell's output.
model = predictor.get_fusion_at_model_level(number_components)
model
StackingRegressor(estimators=[('regressor_data_image_1',
                               Pipeline(steps=[('preprocessor',
                                                ColumnTransformer(transformers=[('dimensionality_reduction',
                                                                                 Pipeline(steps=[('flatten',
                                                                                                  FlattenNestedArray()),
                                                                                                 ('dimensionality_reduction',
                                                                                                  PCA(n_components=2,
                                                                                                      svd_solver='full')),
                                                                                                 ('scaler_pre',
                                                                                                  StandardScaler())]),
                                                                                 'data_image_1'),
                                                                                ('gender_and_site_encoded',
                                                                                 OneHotEnc...
                                                ColumnTransformer(transformers=[('dimensionality_reduction',
                                                                                 Pipeline(steps=[('flatten',
                                                                                                  FlattenNestedArray()),
                                                                                                 ('dimensionality_reduction',
                                                                                                  PCA(n_components=2,
                                                                                                      svd_solver='full')),
                                                                                                 ('scaler_pre',
                                                                                                  StandardScaler())]),
                                                                                 'data_image_2'),
                                                                                ('gender_and_site_encoded',
                                                                                 OneHotEncoder(handle_unknown='ignore'),
                                                                                 ['gender'])])),
                                               ('regressor', EMRVR())]))],
                  final_estimator=EMRVR())
Please rerun this cell to show the HTML repr or trust the notebook.