Fusion Model
Data Example
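from models import ModalityPredictorPCA, MODELTYPE
from generate_random_input import generate_fusion_image_input

# synthetic frame: two array-valued image columns plus age and gender
df = generate_fusion_image_input()
# first half of the rows for training/validation, second half held out for testing;
# n - n // 2 keeps the list length equal to len(df) even when the row count is odd
n = len(df)
df["subset"] = ["TRAIN_VALIDATE"] * (n // 2) + ["TEST"] * (n - n // 2)
df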
| | data_image_1 | data_image_2 | age | gender | subset |
|---|---|---|---|---|---|
| 0 | [47.90268504737213, 55.605919219345566, 17.935... | [0.47169120780684504, 0.7675746276459315, 0.92... | 64 | M | TRAIN_VALIDATE |
| 1 | [67.36314460702869, 23.11991970043677, 12.7263... | [0.25488659747587705, 0.8747048208394343, 0.10... | 69 | M | TRAIN_VALIDATE |
| 2 | [27.779335170284362, 8.69078805515224, 16.5922... | [0.8004833246345652, 0.5331968388732692, 0.438... | 28 | F | TRAIN_VALIDATE |
| 3 | [42.58482520906333, 38.214889508728895, 48.448... | [0.8525791325173845, 0.7134357681855932, 0.422... | 89 | M | TRAIN_VALIDATE |
| 4 | [35.93036363590782, 31.917923604190253, 29.876... | [0.5758026759371442, 0.27679724046350107, 0.40... | 43 | F | TRAIN_VALIDATE |
| ... | ... | ... | ... | ... | ... |
| 495 | [44.86351359507677, 6.09766310076842, 3.646045... | [0.14416804663111982, 0.7082797149926434, 0.68... | 80 | F | TEST |
| 496 | [12.366117329906821, 22.60410325294449, 17.973... | [0.2228989522235434, 0.6819998779459963, 0.076... | 44 | M | TEST |
| 497 | [54.178132638627815, 22.0459855284983, 36.1001... | [0.15068600550198197, 0.6247157139396373, 0.03... | 69 | M | TEST |
| 498 | [33.35691101796781, 6.154465738131789, 16.9540... | [0.006046339867387784, 0.11503217773837293, 0.... | 42 | F | TEST |
| 499 | [25.88994784520547, 32.937509813732596, 14.926... | [0.18396047157757056, 0.8423439769892563, 0.87... | 46 | F | TEST |
500 rows × 5 columns
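The body of `generate_fusion_image_input` is not shown here. As a rough, hypothetical stand-in, the sketch below builds a frame with the same five columns, where each image cell holds a 1-D NumPy array. The helper name `make_fusion_like_frame`, the feature count, and the value ranges are assumptions read off the printed table, not the package's actual generator.

```python
import numpy as np
import pandas as pd


def make_fusion_like_frame(n_rows=500, n_features=100, seed=0):
    """Hypothetical stand-in for generate_fusion_image_input().

    Shapes and value ranges are assumptions inferred from the printed table.
    """
    rng = np.random.default_rng(seed)
    return pd.DataFrame({
        # each cell holds one subject's 1-D feature vector for that modality
        "data_image_1": list(rng.uniform(0, 70, size=(n_rows, n_features))),
        "data_image_2": list(rng.uniform(0, 1, size=(n_rows, n_features))),
        "age": rng.integers(18, 90, size=n_rows),
        "gender": rng.choice(["M", "F"], size=n_rows),
    })


df = make_fusion_like_frame()
n = len(df)
# first half for training/validation, second half held out for testing
df["subset"] = ["TRAIN_VALIDATE"] * (n // 2) + ["TEST"] * (n - n // 2)
print(df["subset"].value_counts())
```

Storing whole arrays in object-dtype cells keeps one row per subject, which is why a flattening step is needed before any scikit-learn transformer can use the image columns.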
Model
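from sklearn import set_config

# render estimators as an interactive HTML diagram in Jupyter
set_config(display="diagram")

# number of PCA components kept per imaging modality
number_components = 2
predictor = ModalityPredictorPCA(df, ["data_image_1", "data_image_2"], MODELTYPE.FUSION, 5)
# fusion at model level: one regressor per modality, combined by a stacking regressor
model = predictor.get_fusion_at_model_level(number_components)
model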
StackingRegressor(estimators=[('regressor_data_image_1', Pipeline(steps=[('preprocessor', ColumnTransformer(transformers=[('dimensionality_reduction', Pipeline(steps=[('flatten', FlattenNestedArray()), ('dimensionality_reduction', PCA(n_components=2, svd_solver='full')), ('scaler_pre', StandardScaler())]), 'data_image_1'), ('gender_and_site_encoded', OneHotEnc... ColumnTransformer(transformers=[('dimensionality_reduction', Pipeline(steps=[('flatten', FlattenNestedArray()), ('dimensionality_reduction', PCA(n_components=2, svd_solver='full')), ('scaler_pre', StandardScaler())]), 'data_image_2'), ('gender_and_site_encoded', OneHotEncoder(handle_unknown='ignore'), ['gender'])])), ('regressor', EMRVR())]))], final_estimator=EMRVR())
Estimator diagram:
StackingRegressor
    regressor_data_image_1: Pipeline
        preprocessor: ColumnTransformer
            dimensionality_reduction ('data_image_1'): FlattenNestedArray() → PCA(n_components=2, svd_solver='full') → StandardScaler()
            gender_and_site_encoded (['gender']): OneHotEncoder(handle_unknown='ignore')
        regressor: EMRVR()
    second base estimator: Pipeline
        ColumnTransformer
            dimensionality_reduction ('data_image_2'): FlattenNestedArray() → PCA(n_components=2, svd_solver='full') → StandardScaler()
            gender_and_site_encoded (['gender']): OneHotEncoder(handle_unknown='ignore')
        regressor: EMRVR()
    final_estimator: EMRVR()
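The diagram shows how model-level fusion is laid out: each imaging modality gets its own pipeline (flatten, PCA, scaling, plus one-hot gender), and a stacking regressor combines the per-modality predictions. The sketch below rebuilds that layout with plain scikit-learn under several assumptions: `FlattenNestedArray` is a hypothetical re-implementation, not the package's class; `Ridge` replaces `EMRVR`; age is assumed to be the target; the data are synthetic; and the helper names `modality_pipeline` and `fusion_at_model_level` are made up for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA
from sklearn.ensemble import StackingRegressor
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler


class FlattenNestedArray(TransformerMixin, BaseEstimator):
    """Hypothetical re-implementation: stack a column of 1-D arrays into a 2-D matrix."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        if isinstance(X, pd.DataFrame):  # tolerate a one-column frame
            X = X.iloc[:, 0]
        return np.vstack([np.asarray(cell, dtype=float) for cell in X])


def modality_pipeline(column, n_components=2):
    """One base learner: PCA features of one image column plus one-hot gender."""
    image_branch = Pipeline([
        ("flatten", FlattenNestedArray()),
        ("dimensionality_reduction", PCA(n_components=n_components, svd_solver="full")),
        ("scaler_pre", StandardScaler()),
    ])
    preprocessor = ColumnTransformer([
        # a scalar column name passes a 1-D Series to the image branch
        ("dimensionality_reduction", image_branch, column),
        ("gender_and_site_encoded", OneHotEncoder(handle_unknown="ignore"), ["gender"]),
    ])
    # Ridge stands in for EMRVR (assumption: any scikit-learn regressor fits here)
    return Pipeline([("preprocessor", preprocessor), ("regressor", Ridge())])


def fusion_at_model_level(columns, n_components=2):
    """Model-level fusion: one regressor per modality, stacked under a final regressor."""
    return StackingRegressor(
        estimators=[(f"regressor_{c}", modality_pipeline(c, n_components)) for c in columns],
        final_estimator=Ridge(),
    )


# Synthetic stand-in data (assumption): two array-valued columns, gender, age as target
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({
    "data_image_1": list(rng.normal(size=(n, 50))),
    "data_image_2": list(rng.normal(size=(n, 50))),
    "gender": rng.choice(["M", "F"], size=n),
    "age": rng.integers(18, 90, size=n),
})
df["subset"] = ["TRAIN_VALIDATE"] * (n // 2) + ["TEST"] * (n - n // 2)

train = df[df["subset"] == "TRAIN_VALIDATE"]
test = df[df["subset"] == "TEST"]
model = fusion_at_model_level(["data_image_1", "data_image_2"])
# columns not named in the ColumnTransformer (age, subset) are dropped by default
model.fit(train, train["age"])
predictions = model.predict(test)
```

Because `StackingRegressor` fits its final estimator on out-of-fold predictions of the base pipelines (5-fold cross-validation by default), the final regressor learns how to weight the two modalities without seeing predictions made on the base learners' own training rows.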