Example #1
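    # A unittest.TestCase method: assumes `import copy`, `import torch`, and
    # EndModel are in scope, and that self.single_problem supplies (Xs, Ys).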
    def test_checkpointing(self):
        """Confirm that different checkpoints are being saved with checkpoint_every on"""
        em = EndModel(
            seed=1,
            batchnorm=False,
            dropout=0.0,
            layer_out_dims=[2, 10, 2],
            verbose=False,
        )
        Xs, Ys = self.single_problem
        em.train_model(
            (Xs[0], Ys[0]),
            valid_data=(Xs[1], Ys[1]),
            n_epochs=5,
            checkpoint=True,
            checkpoint_every=1,
        )
        test_model = copy.deepcopy(em.state_dict())

        # The epoch-4 checkpoint was written before training finished, so its
        # weights should differ from the final model weights...
        new_model = torch.load("checkpoints/model_checkpoint_4.pth")
        self.assertFalse(
            torch.all(
                torch.eq(
                    test_model["network.1.0.weight"],
                    new_model["model"]["network.1.0.weight"],
                )
            )
        )
        # ...while the checkpoint from the last epoch (5) should match exactly.
        new_model = torch.load("checkpoints/model_checkpoint_5.pth")
        self.assertTrue(
            torch.all(
                torch.eq(
                    test_model["network.1.0.weight"],
                    new_model["model"]["network.1.0.weight"],
                )
            )
        )
Example #2
import os
import pickle
from glob import glob

import torch

# Library/project imports (module paths assumed from the surrounding codebase);
# FrameEncoderOC, load_labels, load_dataset, and get_data_loader are project helpers.
from metal.end_model import EndModel
from metal.modules import LSTMModule


def train_model(args):
    hidden_size = 128
    num_classes = 2
    encode_dim = 1000  # value from get_frm_output_size()

    L, Y = load_labels(args)
    data_list = {}
    data_list["dev"] = glob(args.dev + "/la_4ch/*.npy")
    data_list["test"] = glob(args.test + "/la_4ch/*.npy")

    # Create datasets and dataloaders for the end model
    dev, test = load_dataset(data_list, Y)
    data_loader = get_data_loader(dev, test, args.batch_size, args.num_workers)
    # len(data_loader["dev"]) == 1500 / batch_size
    # len(data_loader["test"]) == 1000 / batch_size

    # Define input encoder
    cnn_encoder = FrameEncoderOC

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Define LSTM module
    lstm_module = LSTMModule(
        encode_dim,
        hidden_size,
        bidirectional=False,
        verbose=False,
        lstm_reduction="attention",
        encoder_class=cnn_encoder,
    )

    init_kwargs = {
        "layer_out_dims": [hidden_size, num_classes],
        "input_module": lstm_module,
        "optimizer": "adam",
        "verbose": False,
        "input_batchnorm": False,
        "use_cuda": device == "cuda",
        "seed": args.seed,
        "device": device,
    }

    end_model = EndModel(**init_kwargs)

    # Save the constructor kwargs so the model can be rebuilt later
    if not os.path.exists(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)

    with open(args.checkpoint_dir + "/init_kwargs.pickle", "wb") as f:
        pickle.dump(init_kwargs, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Train end model
    end_model.train_model(
        train_data=data_loader["dev"],
        valid_data=data_loader["test"],
        l2=args.weight_decay,
        lr=args.lr,
        n_epochs=args.n_epochs,
        log_train_every=1,
        verbose=True,
        progress_bar=True,
        loss_weights=[0.55, 0.45],
        batchnorm=args.batchnorm,
        middle_dropout=args.dropout,
        checkpoint=False,
        # Checkpointing options, left disabled here:
        # checkpoint_every=args.n_epochs,
        # checkpoint_best=False,
        # checkpoint_dir=args.checkpoint_dir,
        # validation_metric="f1",
    )

    # Evaluate end model on the test set
    end_model.score(
        data_loader["test"],
        verbose=True,
        metric=["accuracy", "precision", "recall", "f1", "roc-auc", "ndcg"],
    )

    # Save final model weights plus test-set scores
    state = {
        "model": end_model.state_dict(),
        "score": end_model.score(
            data_loader["test"],
            verbose=False,
            metric=["accuracy", "precision", "recall", "f1", "roc-auc", "ndcg"],
        ),
    }
    checkpoint_path = f"{args.checkpoint_dir}/best_model.pth"
    torch.save(state, checkpoint_path)
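
Because train_model() pickles init_kwargs and stores the weights under a "model" key, the trained model can be rebuilt later without re-running training. A minimal sketch assuming exactly the file layout written above (load_trained_model is a hypothetical helper, not part of the example; the metal import path is assumed as before):

import pickle

import torch

from metal.end_model import EndModel  # module path assumed, as above

def load_trained_model(checkpoint_dir):
    # Rebuild the model from the pickled constructor kwargs...
    with open(checkpoint_dir + "/init_kwargs.pickle", "rb") as f:
        init_kwargs = pickle.load(f)
    end_model = EndModel(**init_kwargs)
    # ...then restore the weights saved under the "model" key.
    state = torch.load(checkpoint_dir + "/best_model.pth")
    end_model.load_state_dict(state["model"])
    return end_model, state["score"]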