La position actuelle:Accueil du site>Le modèle de chargement pytorch n'importe que des poids partiels de couche, c'est - à - dire qu'il saute la méthode de spécification de la couche réseau

Le modèle de chargement pytorch n'importe que des poids partiels de couche, c'est - à - dire qu'il saute la méthode de spécification de la couche réseau

2022-05-15 07:14:46M0 61899108

Besoins

PytorchLors du chargement du modèle,N'importez que des poids partiels de calque,Sauter la section spécifier la couche réseau.(Les fichiers de poids sont stockés sousdictForme)

Méthode 1

Méthodes courantes:Lors du chargement des poidsifFiltrer la couche réseau

'''
# modelStructure du réseau définie pour:
class model(nn.Module):
    def __init__(self):
        super(model,self).__init__()
        ……

    def forward(self,x):
        ……
        return x
'''

model = model()  
# loadParamètres du modèle existant(Fichier de poids),Le nom du suffixe peut être différent    
pretrained_dict = torch.load('model.pkl')
model_dict = model.state_dict()
# Le fait est que,Demodel_dictLire danskey、valueHeure,AvecifFiltrer les couches réseau indésirables 
pretrained_dict = {key: value for key, value in pretrained_dict.items() if (key in model_dict and 'Prediction' not in key)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Méthode 2

Ne correspond pas exactement, Seuls les paramètres présents dans le poids sont chargés , Sauter si ça ne correspond pas

# load_state_dict() Par défautstrict=True,Besoin d'une correspondance parfaite,Sinon, une erreur est signalée.
# Modifier comme suit:strict=FalseAprès, Ne correspond qu'aux paramètres existants 
pretrained_dict = torch.load(weight_path)
model.load_state_dict(pretrained_dict, strict=False)

Méthode III

  Ne pas utiliser le fichier de poids original pour la formation , Copie du fichier de poids original , La copie ne contient que les couches réseau requises , Ensuite, les fichiers de poids de copie sont directement utilisés pour la formation .

    #  Copie du fichier de poids original , La copie ne contient que les couches réseau requises ,
    #  Formation ultérieure directement à l'aide de fichiers copiés .
    import pickle

    model = model()
    net = model
    path_weight = 'R-50.pkl'
    path_weight2 = 'R2-50.pkl'

    with open(path_weight,'rb') as f:
        obj=f.read()
    # Avecpickle.loads() Informations sur le poids de chargement 
    la_obj=pickle.loads(obj,encoding='latin1')
    # AvecifFiltrer
    weights= {key: value for key, value in la_obj.items()}
              #if key in la_obj and 'backbone.bottom_up.stem.conv1.weight' not in key}
    # Utiliserprint Voir les informations du fichier de poids  
    print(weights)
    
    #  Copie en profondeur d'un autre document 
    state_dict = copy.deepcopy(weights)
    with open(path_weight2,'wb') as f2:
        pickle.dump(state_dict, f2)

    # Peut écriretxt, Facilité d'accès aux informations 
    path_weight2 = 'R2-101.txt'
    inf = str(state_dict)
    ff = open(path_weight2,'w')
    ff.write(inf)

Voici les exigences particulières pour l'optimisation des paramètres de chargement :Paramètres fixes、 Ou la vitesse de mise à jour des paramètres est différente .

Méthode IV

Si vous chargez ces paramètres , Certains paramètres ne doivent pas être mis à jour , C'est - à - dire fixe ,Ne pas participer à la formation, Les attributs de gradient que vous devez définir manuellement pour ces paramètres sont: Fasle,Et dansoptimizer Filtrer ces paramètres lors du transfert des paramètres :

#  Après avoir chargé les paramètres du modèle de pré - Formation ...
for name, value in model.named_parameters():
    if name Certaines conditions sont remplies:
        value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

Méthode V

Si vous chargez ces paramètres , Tous les paramètres sont mis à jour , Mais la vitesse de mise à jour de certains paramètres et d'autres est nécessaire (Taux d'apprentissagelearning rate)C'est différent, Il est préférable de savoir quels sont les noms de ces paramètres :

#  Après avoir chargé les paramètres du modèle de pré - Formation ...
for name, value in model.named_parameters():
    print(name)
# Ou
print(model.state_dict().keys())

Supposons qu'il y ait encoder,viewerEtdecoderDeux parties, Les noms des paramètres sont respectivement :

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

Supposons que les exigencesencode、viewerLe taux d'apprentissage de1e-6, decoderLe taux d'apprentissage de1e-4, Lors du passage des paramètres dans l'optimiseur :

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
                              {'params':model.decoder.parameters()}
                              ],
                              lr=1e-4, momentum=0.9)

Le résultat du Code est de diviser decoderParamètrelearning_rate=1e-4 Extérieur, Autres paramètres learning_rate=1e-6.
En coursoptimizerHeure, Et la méthode générale de transmission des paramètres torch.optim.Adam(model.parameters(), lr=xxx) C'est différent., La section paramètres utilise un list, list Chaque élément de paramsEtlr Deux valeurs clés .Si ce n'est pas le cas lrPuis appliquerAdamDelrPropriétés.Adam Sauf que lr, Tout le reste est partagé par les paramètres (Par exemple,momentum).
 

Problèmes rencontrés

torch.load ChargementFichier de poidsErreur de temps Magic Number Error 

Parfois, on utilise torch.load Chargement Fichiers de poids plus anciens Erreur possible Magic Number Error, Cela peut être dû au fichier Utiliser pickle Stocké et encodé en utilisant latin1, Vous pouvez maintenant charger :

Pour filtrer , La même chose peut être ajoutée après ifPorter un jugement.

import pickle
with open(weights_path, 'rb') as f:
    obj = f.read()
# AvecpickleEn coursload,Le mode de codage estlatin1
weights = {key: weight_dict for key, weight_dict in pickle.loads(obj,encoding='latin1').items()}
# Même chose.,Ça marcheif Le jugement passe au crible 
# weights = {key: value for key, value in pickle.loads(obj,encoding='latin1').items() if (key in model_dict and 'Prediction' not in key)}
model.load_state_dict(weights) 

TypeError: a bytes-like object is required, not 'str'

python3Etpython2Il y a une différence dans le décodage de la valeur de retour de socket.

La prise est socket,Utilisé pour décrire IP Adresse et Port, L'application envoie une demande au réseau par socket ou répond à une demande réseau , Une interface de données qui peut être considérée comme un réseau informatique . Il existe actuellement deux types de sockets : Basé sur des fichiers et des réseaux .

Solutions

En utilisant la fonction encode() Etdecode():

  1. str Adoption encode() Le Code de fonction devient bytes
  2. bytes Adoption decode() Le Code de fonction devient str.( Lorsque nous lisons un flux d'octets à partir d'un réseau ou d'un disque , Et les données que j'ai lues sont bytes)

Supplément:

str --> bytes

# Déclarez une chaînes:
>>> s = 'abc'
>>> type(s)
<class 'str'>

#  Quatre modes de conversion :
>>> b1 = s.encode()
>>> type(b1)
<class 'bytes'>
>>> b2 = str.encode(s)
>>> type(b2)
<class 'bytes'>
>>> b3 = s.encode(encoding='utf-8')
>>> type(b3)
<class 'bytes'>
>>> b4 = bytes(s,encoding='utf-8')
>>> type(b4)
<class 'bytes'>

bytes --> str

# Déclarez unbytes:
>>> b = b'abc'
>>> type(b)
<class 'bytes'>

# Trois modes de conversion:
>>> s1 = bytes.decode(b)
>>> type(s1)
<class 'str'>
>>> s2 = b.decode()
>>> type(s2)
<class 'str'>
>>> s3 = str(b,encoding='utf-8')
>>> type(s3)
<class 'str'>

Blog de référence

Pytorch Méthode d'importation de poids de couche partiels seulement _ Le blog de Ximeng Lihai -CSDNBlogs_pytorch Chargement des poids partiels

pytorchAffiner le modèle— Ne charger que certaines couches du modèle de pré - Formation _Farmer Mountain Spring2Le blog de-CSDNBlogs

Pytorch Le modèle de chargement ne correspond pas exactement & Chargement des poids partiels des paramètres seulement load_hxxjxwBlog de-CSDNBlogs_pytorch Le modèle de chargement ne correspond pas au saut

pytorch Après avoir chargé le modèle de pré - formation , Je veux juste m'entraîner à chaque niveau. ?_ Mubai. -Blog de-CSDNBlogs_pytorch Former seulement le dernier niveau

PyTorch | Enregistrer et charger le modèle - Oui. (zhihu.com)

torch.load Erreur lors du chargement du poids Magic Number Error - Regarder la fraîcheur des joueurs haut de gamme - La blogosphère (cnblogs.com)

PythonErreur signalée:TypeError: a bytes-like object is required, not ‘str‘_ Le blog de la troisième soeur de programmeur -CSDNBlogs 

Mentions de copyright
Auteur de cet article [M0 61899108],Réimpression s’il vous plaît apporter le lien vers l’original, merci
https://fra.chowdera.com/2022/135/202205142322539438.html

Recommandé au hasard