def test_checkpointing(self):
    """Confirm that different checkpoints are being saved with checkpoint_every on"""
    em = EndModel(
        seed=1,
        batchnorm=False,
        dropout=0.0,
        layer_out_dims=[2, 10, 2],
        verbose=False,
    )
    Xs, Ys = self.single_problem
    em.train_model(
        (Xs[0], Ys[0]),
        valid_data=(Xs[1], Ys[1]),
        n_epochs=5,
        checkpoint=True,
        checkpoint_every=1,
    )
    # Snapshot the final weights to compare against the saved checkpoints
    test_model = copy.deepcopy(em.state_dict())

    # The epoch-4 checkpoint should differ from the final weights
    new_model = torch.load("checkpoints/model_checkpoint_4.pth")
    self.assertFalse(
        torch.all(
            torch.eq(
                test_model["network.1.0.weight"],
                new_model["model"]["network.1.0.weight"],
            )
        )
    )

    # The epoch-5 (final) checkpoint should match the final weights exactly
    new_model = torch.load("checkpoints/model_checkpoint_5.pth")
    self.assertTrue(
        torch.all(
            torch.eq(
                test_model["network.1.0.weight"],
                new_model["model"]["network.1.0.weight"],
            )
        )
    )
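# A minimal sketch of restoring a model from one of these checkpoints. It
# assumes `em` is an EndModel built with the same kwargs as in the test;
# the checkpoint layout (a dict holding the state_dict under the "model"
# key) matches what the assertions above read via new_model["model"].
# EndModel subclasses nn.Module, so load_state_dict should be available.
#
#     checkpoint = torch.load("checkpoints/model_checkpoint_5.pth")
#     em.load_state_dict(checkpoint["model"])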
def train_model(args):
    hidden_size = 128
    num_classes = 2
    encode_dim = 1000  # using get_frm_output_size()

    L, Y = load_labels(args)

    # Create datasets and dataloaders
    data_list = {}
    data_list["dev"] = glob(args.dev + "/la_4ch/*.npy")
    data_list["test"] = glob(args.test + "/la_4ch/*.npy")
    dev, test = load_dataset(data_list, Y)
    data_loader = get_data_loader(dev, test, args.batch_size, args.num_workers)

    # Define input encoder
    cnn_encoder = FrameEncoderOC

    cuda = torch.cuda.is_available()
    device = "cuda" if cuda else "cpu"

    # Define LSTM module
    lstm_module = LSTMModule(
        encode_dim,
        hidden_size,
        bidirectional=False,
        verbose=False,
        lstm_reduction="attention",
        encoder_class=cnn_encoder,
    )

    init_kwargs = {
        "layer_out_dims": [hidden_size, num_classes],
        "input_module": lstm_module,
        "optimizer": "adam",
        "verbose": False,
        "input_batchnorm": False,
        "use_cuda": cuda,
        "seed": args.seed,
        "device": device,
    }
    end_model = EndModel(**init_kwargs)

    # Persist the constructor kwargs so the saved model can be rebuilt later
    if not os.path.exists(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    with open(args.checkpoint_dir + "/init_kwargs.pickle", "wb") as f:
        pickle.dump(init_kwargs, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Train end model
    end_model.train_model(
        train_data=data_loader["dev"],
        valid_data=data_loader["test"],
        l2=args.weight_decay,
        lr=args.lr,
        n_epochs=args.n_epochs,
        log_train_every=1,
        verbose=True,
        progress_bar=True,
        loss_weights=[0.55, 0.45],  # per-class weighting of the loss
        batchnorm=args.batchnorm,
        middle_dropout=args.dropout,
        checkpoint=False,
    )

    # Evaluate end model
    end_model.score(
        data_loader["test"],
        verbose=True,
        metric=["accuracy", "precision", "recall", "f1", "roc-auc", "ndcg"],
    )

    # Save model weights and final scores
    state = {
        "model": end_model.state_dict(),
        "score": end_model.score(
            data_loader["test"],
            verbose=False,
            metric=["accuracy", "precision", "recall", "f1", "roc-auc", "ndcg"],
        ),
    }
    checkpoint_path = f"{args.checkpoint_dir}/best_model.pth"
    torch.save(state, checkpoint_path)
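# Hypothetical driver for train_model(). The flag names below mirror the
# attributes the function reads (args.dev, args.test, args.batch_size, and
# so on) but are assumptions, not the project's actual CLI; defaults are
# illustrative only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train LSTM end model")
    parser.add_argument("--dev", type=str, required=True, help="dev data root")
    parser.add_argument("--test", type=str, required=True, help="test data root")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--batchnorm", action="store_true")
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
    train_model(parser.parse_args())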