import torch
from tqdm import tqdm

# trainDataset, MainModel, and EPOCHS are defined elsewhere in the project.
def train():
    dataset = trainDataset()
    model = MainModel(dataset.vocab_size()[0], dataset.vocab_size()[1])
    # model.load_state_dict(torch.load("model2.pth"))
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, pin_memory=True)
    for epoch in range(EPOCHS):
        print(f"EPOCH: {epoch + 1}/{EPOCHS}")
        losses = []
        for idx, data in tqdm(enumerate(dataloader)):
            outputs = model(data["source"].cuda(), data["target"].cuda(),
                            data["alignment"].cuda())
            loss = torch.nn.functional.binary_cross_entropy(
                outputs.view(-1), data["predictions"].cuda().view(-1).float())
            # print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.detach())
        print(f"Mean Loss for Epoch {epoch + 1} is {sum(losses) / len(losses)}")
        torch.save(model.state_dict(), "model_lstm.pth")
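# A minimal sketch of the dataset interface train() assumes: vocab_size() returning a
# (source, target) pair and items that are dicts with "source", "target", "alignment",
# and "predictions" keys. The class name matches the call above; all shapes and the
# random contents are assumptions for illustration only.
from torch.utils.data import Dataset

class trainDataset(Dataset):
    def __init__(self, n_samples=8, src_vocab=100, tgt_vocab=120, seq_len=16):
        self.n_samples, self.seq_len = n_samples, seq_len
        self._vocab = (src_vocab, tgt_vocab)

    def vocab_size(self):
        # train() indexes this as vocab_size()[0] / vocab_size()[1]
        return self._vocab

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return {
            "source": torch.randint(0, self._vocab[0], (self.seq_len,)),
            "target": torch.randint(0, self._vocab[1], (self.seq_len,)),
            "alignment": torch.zeros(self.seq_len, self.seq_len),
            "predictions": torch.randint(0, 2, (self.seq_len,)),
        }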
def main(train_dir, val_dir, checkpoint_dir, batch_size, image_size=512,
         num_epochs=10, checkpoint_name=None, num_workers=1, pin_memory=True,
         log_dir="logs", model_name=None, train_csv=None, val_csv=None):
    # declare datasets
    train_ds = DataFolder(root_dir=train_dir,
                          transform=transform(image_size, is_training=True),
                          csv_path=train_csv)
    val_ds = DataFolder(root_dir=val_dir,
                        transform=transform(image_size, is_training=False),
                        csv_path=val_csv)
    train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers,
                              pin_memory=pin_memory, shuffle=True)
    # validation order does not affect accuracy, so no shuffling is needed
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers,
                            pin_memory=pin_memory, shuffle=False)

    # init model
    model = MainModel(128, model_name)

    # configure loss, optimizer, and AMP gradient scaler
    loss_fn = nn.CrossEntropyLoss()
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    # checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    # save_checkpoint(checkpoint, os.path.join(checkpoint_dir, "checkpoint_initialize.pth.tar"))
    # return

    # optionally resume from an existing checkpoint
    if checkpoint_name:
        ckp_path = os.path.join(checkpoint_dir, checkpoint_name)
        load_checkpoint(torch.load(ckp_path), model, optimizer)
        check_accuracy(val_loader, model, device)

    # training
    for epoch in range(num_epochs):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, device, epoch,
                 log_dir=log_dir)
        check_accuracy(val_loader, model, device)
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_checkpoint(checkpoint,
                        os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth.tar"))
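# A hypothetical invocation of main(). The directory and CSV paths are placeholders,
# and "efficientnet_b0" is an assumed backbone identifier; substitute whatever
# MainModel actually accepts for model_name.
if __name__ == "__main__":
    main(
        train_dir="data/train",
        val_dir="data/val",
        checkpoint_dir="checkpoints",
        batch_size=16,
        image_size=512,
        num_epochs=10,
        model_name="efficientnet_b0",
        train_csv="data/train.csv",
        val_csv="data/val.csv",
    )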
class Config:
    def __init__(self, batch_size=32, lr=0.00015, p_horizontalflip=0.4,
                 model_type='ResNet101', training_mode='only_new'):
        ## INFO ABOUT EXPERIMENT
        self.logsFileName = 'LOGS'
        self.logsFileName_finetuning = 'LOGS_finetuning'
        self.seed = 13
        seed_torch(self.seed)

        if os.path.exists('./Logs/' + self.logsFileName + '.csv'):
            if training_mode == 'only_new':
                self.df_logger = Logger(self.logsFileName + '.csv', 'df')
                self.experiment_name = 'exp{}'.format(len(self.df_logger.logsFile)) + '_end_epoch'
                self.df_logger.save()
            elif training_mode == 'finetuning':
                self.df_logger = Logger(self.logsFileName_finetuning + '.csv', 'df')
                self.experiment_name = 'exp{}'.format(len(self.df_logger.logsFile)) + '_end_epoch'
                self.df_logger.save()
        else:
            self.experiment_name = 'exp{}'.format(0) + '_end_epoch'

        self.exper_type = 'data_imgsize_300'
        self.img_size = 300
        # self.img_size_crop = 300

        ## MODEL PARAMETERS
        self.weights_dir = './Model_weights/'
        self.weights_dir_finetuning = './Model_weights_finetuning/'
        self.model_type = model_type
        self.model = MainModel(model_type=self.model_type).model
        self.pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-5)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5,
                                                                    patience=2, verbose=True)
        self.criterion = nn.MSELoss()  # AdaptiveLossFunction(num_dims=1, float_dtype=np.float32, device='cuda:0')
        # self.num_classes = 5
        self.model_param_list = [self.model, self.optimizer, self.scheduler]

        ## EARLY STOPPING
        self.early_stopping_patience = 10
        self.early_stopping = EarlyStopping(self.early_stopping_patience)
        self.early_stopping_loss = 'pytorch'  # or 'kappa'

        ## TRAINING & VALIDATION SETUP
        self.num_workers = 16
        self.n_epochs = 200
        self.batch_size = batch_size
        self.valid_type = 'holdout'  # or 'CV'
        self.valid_size = 0.2
        self.n_folds = 5  # for CV

        ## TRANSFORMER AND DATASET
        self.p_horizontalflip = p_horizontalflip
        self.data_type = 'new'

        ## PRINT FREQUENCY
        self.print_frequency = 50
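# Hypothetical usage: build a Config and hand its members to a training loop.
# train_one_epoch and validate are assumed helpers (not defined in this project
# excerpt), and the EarlyStopping call signature and early_stop attribute are
# likewise assumptions; only the scheduler.step(metric) call follows directly
# from ReduceLROnPlateau's documented interface.
cfg = Config(batch_size=64, lr=3e-4, model_type='ResNet101', training_mode='only_new')

for epoch in range(cfg.n_epochs):
    train_loss = train_one_epoch(cfg.model, cfg.optimizer, cfg.criterion)  # assumed helper
    val_loss = validate(cfg.model, cfg.criterion)                          # assumed helper
    cfg.scheduler.step(val_loss)       # ReduceLROnPlateau steps on the monitored metric
    cfg.early_stopping(val_loss)       # assumed EarlyStopping call signature
    if cfg.early_stopping.early_stop:  # assumed attribute
        break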