La position actuelle:Accueil du site>Le modèle de chargement pytorch n'importe que des poids partiels de couche, c'est - à - dire qu'il saute la méthode de spécification de la couche réseau
Le modèle de chargement pytorch n'importe que des poids partiels de couche, c'est - à - dire qu'il saute la méthode de spécification de la couche réseau
2022-05-15 07:14:46【M0 61899108】
Besoins
PytorchLors du chargement du modèle,N'importez que des poids partiels de calque,Sauter la section spécifier la couche réseau.(Les fichiers de poids sont stockés sousdictForme)
Méthode 1
Méthodes courantes:Lors du chargement des poidsifFiltrer la couche réseau
'''
# modelStructure du réseau définie pour:
class model(nn.Module):
def __init__(self):
super(model,self).__init__()
……
def forward(self,x):
……
return x
'''
model = model()
# loadParamètres du modèle existant(Fichier de poids),Le nom du suffixe peut être différent
pretrained_dict = torch.load('model.pkl')
model_dict = model.state_dict()
# Le fait est que,Demodel_dictLire danskey、valueHeure,AvecifFiltrer les couches réseau indésirables
pretrained_dict = {key: value for key, value in pretrained_dict.items() if (key in model_dict and 'Prediction' not in key)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
Méthode 2
Ne correspond pas exactement, Seuls les paramètres présents dans le poids sont chargés , Sauter si ça ne correspond pas
# load_state_dict() Par défautstrict=True,Besoin d'une correspondance parfaite,Sinon, une erreur est signalée.
# Modifier comme suit:strict=FalseAprès, Ne correspond qu'aux paramètres existants
pretrained_dict = torch.load(weight_path)
model.load_state_dict(pretrained_dict, strict=False)
Méthode III
Ne pas utiliser le fichier de poids original pour la formation , Copie du fichier de poids original , La copie ne contient que les couches réseau requises , Ensuite, les fichiers de poids de copie sont directement utilisés pour la formation .
# Copie du fichier de poids original , La copie ne contient que les couches réseau requises ,
# Formation ultérieure directement à l'aide de fichiers copiés .
import pickle
model = model()
net = model
path_weight = 'R-50.pkl'
path_weight2 = 'R2-50.pkl'
with open(path_weight,'rb') as f:
obj=f.read()
# Avecpickle.loads() Informations sur le poids de chargement
la_obj=pickle.loads(obj,encoding='latin1')
# AvecifFiltrer
weights= {key: value for key, value in la_obj.items()}
#if key in la_obj and 'backbone.bottom_up.stem.conv1.weight' not in key}
# Utiliserprint Voir les informations du fichier de poids
print(weights)
# Copie en profondeur d'un autre document
state_dict = copy.deepcopy(weights)
with open(path_weight2,'wb') as f2:
pickle.dump(state_dict, f2)
# Peut écriretxt, Facilité d'accès aux informations
path_weight2 = 'R2-101.txt'
inf = str(state_dict)
ff = open(path_weight2,'w')
ff.write(inf)
Voici les exigences particulières pour l'optimisation des paramètres de chargement :Paramètres fixes、 Ou la vitesse de mise à jour des paramètres est différente .
Méthode IV
Si vous chargez ces paramètres , Certains paramètres ne doivent pas être mis à jour , C'est - à - dire fixe ,Ne pas participer à la formation, Les attributs de gradient que vous devez définir manuellement pour ces paramètres sont: Fasle,Et dansoptimizer Filtrer ces paramètres lors du transfert des paramètres :
# Après avoir chargé les paramètres du modèle de pré - Formation ...
for name, value in model.named_parameters():
if name Certaines conditions sont remplies:
value.requires_grad = False
# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
Méthode V
Si vous chargez ces paramètres , Tous les paramètres sont mis à jour , Mais la vitesse de mise à jour de certains paramètres et d'autres est nécessaire (Taux d'apprentissagelearning rate)C'est différent, Il est préférable de savoir quels sont les noms de ces paramètres :
# Après avoir chargé les paramètres du modèle de pré - Formation ...
for name, value in model.named_parameters():
print(name)
# Ou
print(model.state_dict().keys())
Supposons qu'il y ait encoder,viewerEtdecoderDeux parties, Les noms des paramètres sont respectivement :
'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',
Supposons que les exigencesencode、viewerLe taux d'apprentissage de1e-6, decoderLe taux d'apprentissage de1e-4, Lors du passage des paramètres dans l'optimiseur :
ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
{'params':model.decoder.parameters()}
],
lr=1e-4, momentum=0.9)
Le résultat du Code est de diviser decoderParamètrelearning_rate=1e-4 Extérieur, Autres paramètres learning_rate=1e-6.
En coursoptimizerHeure, Et la méthode générale de transmission des paramètres torch.optim.Adam(model.parameters(), lr=xxx) C'est différent., La section paramètres utilise un list, list Chaque élément de paramsEtlr Deux valeurs clés .Si ce n'est pas le cas lrPuis appliquerAdamDelrPropriétés.Adam Sauf que lr, Tout le reste est partagé par les paramètres (Par exemple,momentum).
Problèmes rencontrés
torch.load ChargementFichier de poidsErreur de temps Magic Number Error
Parfois, on utilise torch.load Chargement Fichiers de poids plus anciens Erreur possible Magic Number Error, Cela peut être dû au fichier Utiliser pickle Stocké et encodé en utilisant latin1, Vous pouvez maintenant charger :
Pour filtrer , La même chose peut être ajoutée après ifPorter un jugement.
import pickle
with open(weights_path, 'rb') as f:
obj = f.read()
# AvecpickleEn coursload,Le mode de codage estlatin1
weights = {key: weight_dict for key, weight_dict in pickle.loads(obj,encoding='latin1').items()}
# Même chose.,Ça marcheif Le jugement passe au crible
# weights = {key: value for key, value in pickle.loads(obj,encoding='latin1').items() if (key in model_dict and 'Prediction' not in key)}
model.load_state_dict(weights)
TypeError: a bytes-like object is required, not 'str'
python3Etpython2Il y a une différence dans le décodage de la valeur de retour de socket.
La prise est socket,Utilisé pour décrire IP Adresse et Port, L'application envoie une demande au réseau par socket ou répond à une demande réseau , Une interface de données qui peut être considérée comme un réseau informatique . Il existe actuellement deux types de sockets : Basé sur des fichiers et des réseaux .
Solutions:
En utilisant la fonction encode() Etdecode():
- str Adoption encode() Le Code de fonction devient bytes
- bytes Adoption decode() Le Code de fonction devient str.( Lorsque nous lisons un flux d'octets à partir d'un réseau ou d'un disque , Et les données que j'ai lues sont bytes)
Supplément:
str --> bytes
# Déclarez une chaînes:
>>> s = 'abc'
>>> type(s)
<class 'str'>
# Quatre modes de conversion :
>>> b1 = s.encode()
>>> type(b1)
<class 'bytes'>
>>> b2 = str.encode(s)
>>> type(b2)
<class 'bytes'>
>>> b3 = s.encode(encoding='utf-8')
>>> type(b3)
<class 'bytes'>
>>> b4 = bytes(s,encoding='utf-8')
>>> type(b4)
<class 'bytes'>
bytes --> str
# Déclarez unbytes:
>>> b = b'abc'
>>> type(b)
<class 'bytes'>
# Trois modes de conversion:
>>> s1 = bytes.decode(b)
>>> type(s1)
<class 'str'>
>>> s2 = b.decode()
>>> type(s2)
<class 'str'>
>>> s3 = str(b,encoding='utf-8')
>>> type(s3)
<class 'str'>
Blog de référence
PyTorch | Enregistrer et charger le modèle - Oui. (zhihu.com)
Mentions de copyright
Auteur de cet article [M0 61899108],Réimpression s’il vous plaît apporter le lien vers l’original, merci
https://fra.chowdera.com/2022/135/202205142322539438.html
Recommandé par sidebar
- Erreur de résolution de régression logique: valueerror: Solver lbfgs support only 'L2' or 'none' Penalties, got L1 penalty.
- Oracle OCI Computing, Storage, Network Tools designed to reduce Cloud Complexity
- Journal de bord du projet go [11e mise à jour du projet de la Décennie open source]
- Variables et opérateurs de script Shell
- Parler et trouver un emploi
- C'est la capacité, c'est la culture.
- Tensorflow Learning notes (5)
- Vitest prend en charge le workaround de cjs (scénario commun JS du produit Typescript)
- La réinstallabilité et l'équité des serrures de verrouillage dans les séries de programmation simultanées
- Discussion sur la relation entre Fiori Fundamentals et SAP ui5 Web Components
Devinez que vous aimez
RAM / FIFO Learning Review
La dernière version de 2022 est la version industrielle et commerciale ERP M7 V22. 0 version réseau du logiciel de gestion de la production financière pour l'achat, la vente et le stockage - système de gestion de la fabrication de groupe dans le cloud
[apprentissage automatique 05] régression Lasso et elasticnet
Raccourci idea
Recherche sur la création de fenêtres modales et non modales
[test de performance] chapitre 5 | installation de l'environnement jmeter
Guide d'utilisation de matplotlib, 100 cas du début à la fin! (code source joint)
Dots + Interval stats and geoms
Sigir2022 | recommandation de session basée sur les préférences des utilisateurs en matière de prix et d'intérêts
Cloudreve auto - construit Cloud disk Real station: la capacité et la vitesse sont déterminées par vous - même
Recommandé au hasard
- Construire un pool d'agents gratuits en utilisant la fonction Tencent Cloud
- Installation de redis et types de données de base
- Effet du graphique de rotation JS, mise en œuvre progressive de la transparence
- [stack + deep First Search] Bracket issue Summary
- Notes chapitre 1 Flux et fichiers (6) lecture aléatoire des fichiers et lecture zip des fichiers
- Votre base de données porte - t - elle vraiment un gilet pare - balles?
- Expérience 4 synchronisation et communication des processus
- Leetcode Tencent Selected Practice 50 questions - 557. Inverser le mot III dans la chaîne
- Questions de simulation complètes (y compris les réponses et l'analyse) pour l'Ingénieur en gestion de projet de l'intégration du système d'examen souple
- Cette API Alibaba TENCENT est utilisée pour gérer les artefacts et résoudre les problèmes de documentation
- La base théorique la plus solide du micro - service peut être considérée comme une excellente méthode mentale
- L'Octo, en tant que cadre de communication de service de haute performance du Groupe, peut - il être considéré comme un atout?
- Interview immersive: MySQL Serial gun, combien pouvez - vous combattre?
- Alibaba a demandé: quels modèles de limitation de courant les systèmes de l'entreprise ont - ils utilisés auparavant?
- Différences entre les logiciels ERP tels que faster software, Kingdee Software, UFIDA software, majordome software, dingjie Software et les logiciels d'importation et de vente
- P1439 [modèle] sous - séquence publique la plus longue
- Base de données MySQL (8): type de données - décimale
- Un homme de 38 ans qui vivait seul est mort.
- [freertos Task Recovery and pending]
- UDS - Comment réaliser la demande et la réponse du Service de diagnostic dans CAPL
- UDS - Comment implémenter la lecture du DTC et de son état dans CAPL
- Copier intelligemment tous les fichiers dans plusieurs dossiers à l'emplacement spécifié
- Clip vidéo, ajout de sous - titres SRT à une période de temps de la vidéo
- [SQL Union operator]
- Explication des principes de base des reptiles
- Huawei Device configure un VLAN Multicast Multi - à Multi - fonctions basé sur le VLAN utilisateur
- Copier la liste liée avec un pointeur aléatoire < facteur de difficulté>
- Quels sont les éléments exonérés de l'impôt sur le revenu des particuliers?
- NPM warn read shrinkwrap this version of NPM is compatible with [email protected] , mais package lock.
- Comment gérer la confidentialité des sources
- JVM (XVII) - - chargement du Code octet et de la classe (II) - - ensemble d'instructions du Code octet
- Anglais niveau 6 vocabulaire à haute fréquence sténographie + 2 décembre 2018 audition jour 04
- Configuration optimisée par ordinateur - win10
- Innovation Workshop Li Kaifu: la tendance la plus prometteuse de la Décennie est la technologie médicale
- Comprendre rapidement le CDN en mandarin
- L'apprentissage par petits échantillons n'est - il qu'une question d'auto - Salut universitaire?
- Déclarations communes à la base de données
- Learnopengl Learning Notes - Advanced Data
- La fonction de base de données interroge la base de données MySQL est un problème de temps correct
- Comment créer un nouveau menu en clic droit sous win10 en utilisant typora. Fichier MD