def load_state(T, chkpt_path, device, network_temp=2):
    """Create the network and its optimizer, restoring saved state when available.

    Args:
        T: Passed positionally to ``Network``.
        chkpt_path: Checkpoint file path, or ``None``. The checkpoint is only
            read when the path is not ``None`` and exists on disk.
        device: Device the network is moved to; also used as the
            ``map_location`` when loading the checkpoint.
        network_temp: Passed to ``Network`` as the ``temp`` keyword.

    Returns:
        ``(net, optimizer, games_trained, replay_mem)``. Without a checkpoint,
        ``games_trained`` is ``0`` and ``replay_mem`` is a new ``ReplayMemory()``.
    """
    net = Network(T, temp=network_temp).to(device)
    optimizer = Adam(net.parameters(), lr=1e-4, weight_decay=1e-4)

    has_checkpoint = chkpt_path is not None and os.path.exists(chkpt_path)
    if not has_checkpoint:
        # Fresh start: no games trained yet and an empty replay memory.
        return net, optimizer, 0, ReplayMemory()

    # NOTE(review): the checkpoint presumably stores a pickled ReplayMemory object
    # under REPLAY_MEM_KEY; PyTorch versions whose torch.load defaults to
    # weights_only=True may refuse to load it — confirm against the torch version.
    # torch.load unpickles, so only load trusted checkpoint files.
    saved = torch.load(chkpt_path, map_location=torch.device(device))
    net.load_state_dict(saved[MODEL_KEY])
    optimizer.load_state_dict(saved[OPTIMIZER_KEY])
    return net, optimizer, saved[GAMES_TRAINED_KEY], saved[REPLAY_MEM_KEY]
def load_pretrained_weights(args):
    """Build a ``Network`` from pretrained weights with a freshly initialised classifier.

    Loads a saved state dict and drops every entry whose key contains
    ``'classifier'``. The remaining weights are loaded non-strictly. Every
    parameter outside the classifier is frozen, and
    ``classifier.fc8.weight`` is re-initialised with Xavier-uniform.

    Args:
        args: Namespace-like object providing ``classes`` (passed to
            ``Network``), ``model_path`` (checkpoint path) and ``domain``
            (substituted for ``'RealWorld'`` in ``model_path``).

    Returns:
        The constructed model. It is not moved to any device here.
    """
    model = Network(args.classes)

    # Resolve the checkpoint for the requested domain and load *that* file.
    model_path = args.model_path.replace('RealWorld', args.domain)
    # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts;
    # load_state_dict copies the values into the model's own tensors anyway.
    pretrained = torch.load(model_path, map_location='cpu')

    # Keep only the backbone; the classifier head is trained from scratch.
    backbone = OrderedDict(
        (key, pretrained[key]) for key in pretrained if 'classifier' not in key
    )
    # strict=False because the classifier entries are intentionally absent.
    model.load_state_dict(backbone, strict=False)
    del pretrained, backbone

    # Freeze the backbone. This must go through named_parameters():
    # state_dict() returns detached tensors, so toggling requires_grad on them
    # would leave the model's actual parameters trainable.
    for name, param in model.named_parameters():
        if 'classifier' not in name:
            param.requires_grad = False

    torch.nn.init.xavier_uniform_(model.classifier.fc8.weight)
    return model