def train_one_epoch(self):
    """Run one optimisation pass over ``self.train_loader``.

    Every batch is moved to ``self.device``, scored with ``self.loss`` and
    used for one ``self.optimizer`` step. Loss and accuracy are tracked in
    ``AverageMeter`` objects, and a one-line summary built from the meters'
    ``val`` attributes (rounded to 5 decimals) is sent to ``print_and_log``.
    """
    self.model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # Show a progress bar only when verbose output was requested.
    batches = tqdm(self.train_loader) if self.verbose else self.train_loader

    for inputs, targets in batches:
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        logits = self.model(inputs)
        batch_loss = self.loss(logits, targets)

        # Standard step: clear stale gradients, back-propagate, update.
        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

        loss_meter.update(batch_loss.item())

        # Accuracy is computed on NumPy copies detached from the graph.
        targets_np = targets.cpu().numpy()
        logits_np = logits.detach().cpu().numpy()
        acc_meter.update(get_accuracy(logits_np, targets_np),
                         targets_np.shape[0])

    template = 'Training epoch {} | loss: {} - accuracy: {}'
    summary = template.format(self.cur_epoch,
                              round(loss_meter.val, 5),
                              round(acc_meter.val, 5))
    print_and_log(self.logger, summary)
def train(self):
    """Train the model following the schedule chosen by ``self.config.mode``.

    * ``'crosstest'``: run exactly ``self.config.num_epochs`` epochs and log
      how many were run.
    * ``'val'``: run at most ``self.config.max_epochs`` epochs, validating
      after each one and stopping as soon as the ``EarlyStopper`` (built
      from ``config.patience`` and ``config.min_epochs``) reports that
      training should stop.

    Any other mode returns without training. ``self.cur_epoch`` is updated
    at the start of every epoch.
    """
    mode = self.config.mode

    if mode == 'crosstest':
        for epoch in range(self.config.num_epochs):
            self.cur_epoch = epoch
            self.train_one_epoch()
        message = 'Stopped after ' + str(self.config.num_epochs) + ' epochs'
        print_and_log(self.logger, message)
        return

    if mode != 'val':
        return

    stopper = EarlyStopper(self.config.patience, self.config.min_epochs)
    started_at = time.time()
    for epoch in range(self.config.max_epochs):
        self.cur_epoch = epoch
        self.train_one_epoch()
        val_acc, _ = self.validate()

        # Report the wall-clock duration of the first epoch only
        # (training plus validation).
        if epoch == 0:
            print('{} s/it'.format(round(time.time() - started_at, 3)))

        if not stopper.update_and_check(val_acc, printing=True):
            continue

        message = 'Stopped early with patience {}'.format(
            self.config.patience)
        print_and_log(self.logger, message)
        break
def run(self):
    """Entry point: build data loaders, (re)initialise the model and train.

    * ``'val'`` mode: fetch the train/validation loaders once from
      ``self.mngr.val_ldrs()``, initialise the model and train it.
    * ``'crosstest'`` mode: for each of ``self.config.num_folds`` folds,
      initialise a fresh model, log the fold number, fetch that fold's
      loaders from ``self.mngr.crosstest_ldrs(fold)``, train and validate.

    Any other mode is a no-op.
    """
    mode = self.config.mode

    if mode == 'val':
        loaders = self.mngr.val_ldrs()
        self.train_loader, self.val_loader = loaders
        self.initialize_model()
        self.train()
    elif mode == 'crosstest':
        for fold_idx in range(self.config.num_folds):
            # Each fold starts from a freshly initialised model.
            self.initialize_model()
            print_and_log(self.logger, 'Fold number {}'.format(fold_idx))
            fold_loaders = self.mngr.crosstest_ldrs(fold_idx)
            self.train_loader, self.val_loader = fold_loaders
            self.train()
            self.validate()
def train_one_epoch(self):
    """Run one training pass over ``self.train_loader`` with gradient
    accumulation.

    The model is called as ``self.model(x, attention_mask=..., labels=y)``
    and is expected to return ``(loss, output)``. Gradients are accumulated
    over ``self.accumulation_steps`` batches; after each full window the
    optimizer and the scheduler are stepped and the gradients cleared.

    The loss recorded for logging is the *unscaled* batch loss, so the
    value reported here is on the same scale as the one reported by
    validation. Only the tensor that is back-propagated is divided by
    ``self.accumulation_steps``.

    A summary line built from the meters' ``val`` attributes (rounded to
    5 decimals) is passed to ``print_and_log`` once the loader is exhausted.
    """
    self.optimizer.zero_grad()
    self.model.train()
    loss = AverageMeter()
    acc = AverageMeter()

    # Show a progress bar only when verbose output was requested.
    batches = tqdm(self.train_loader) if self.verbose else self.train_loader

    for i, (x, y) in enumerate(batches):
        # Mask is 1.0 wherever the input id is positive
        # (presumably 0 is the padding id -- TODO confirm).
        attention_mask = (x > 0).float().to(self.device)
        x = x.to(self.device)
        y = y.to(self.device)

        current_loss, output = self.model(x,
                                          attention_mask=attention_mask,
                                          labels=y)

        # Log the true batch loss; scaling is only needed for the gradients.
        loss.update(current_loss.detach().item())

        # Dividing by the window size makes the accumulated gradient the
        # mean over the window rather than the sum.
        (current_loss / self.accumulation_steps).backward()

        # NOTE: gradient clipping is intentionally not applied here.
        if (i + 1) % self.accumulation_steps == 0:
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        output = output.detach().cpu().numpy()
        y = y.cpu().numpy()
        accuracy = get_accuracy(output, y)
        acc.update(accuracy, y.shape[0])

    # NOTE(review): when the number of batches is not a multiple of
    # ``accumulation_steps`` the gradients of the trailing partial window
    # are never applied; they are cleared by the ``zero_grad()`` at the
    # start of the next call. Left unchanged because an extra
    # ``scheduler.step()`` could disagree with how the scheduler was sized.

    s = 'Training epoch {} | loss: {} - accuracy: {}'.format(
        self.cur_epoch, round(loss.val, 5), round(acc.val, 5))
    print_and_log(self.logger, s)
def __init__(self, config, pct_usage=1, frac=0.5, geo=0.5):
    """Set up logging, loss, data manager and device for the agent.

    Args:
        config: Configuration object; ``config.aug_mode`` is read here and
            the whole object is forwarded to ``ProConDataManager``.
        pct_usage: Forwarded to the data manager and reported as a
            percentage (``int(100 * pct_usage)``) on stdout.
        frac: Forwarded to the data manager; reported as the percentage of
            original training data when ``config.aug_mode == 'sr'``.
        geo: Forwarded to the data manager; reported as the geometric
            parameter when ``config.aug_mode == 'sr'``.
    """
    self.config = config
    self.pct_usage = pct_usage
    self.frac = frac
    self.geo = geo

    self.logger = logging.getLogger('BiLSTMAgent')
    self.cur_epoch = 0
    self.loss = CrossEntropyLoss()
    self.mngr = ProConDataManager(self.config, self.pct_usage, frac, geo)

    # Use the first GPU when CUDA is available, otherwise the CPU.
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(device_name)

    print('Using {}% of the dataset.'.format(int(100 * self.pct_usage)))
    self.logger.info('Using ' + str(self.pct_usage) + ' of the dataset.')

    if self.config.aug_mode == 'sr':
        original_pct = str(int(100 * self.frac))
        print_and_log(
            self.logger,
            original_pct + '% of the training data will be '
            'original, the rest augmented.')
        print_and_log(
            self.logger,
            'The geometric parameter is ' + str(geo) + '.')
def validate(self):
    """Evaluate the model on ``self.val_loader`` with gradients disabled.

    The model is put in eval mode and called as
    ``self.model(x, attention_mask=..., labels=y)``, which is expected to
    return ``(loss, output)``. A summary line is passed to
    ``print_and_log``.

    Returns:
        tuple: ``(accuracy, loss)`` taken from the ``val`` attribute of the
        accuracy and loss meters, in that order.
    """
    self.model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    with torch.no_grad():
        for tokens, labels in self.val_loader:
            # Mask is 1.0 wherever the input id is positive
            # (presumably 0 is the padding id -- TODO confirm).
            mask = (tokens > 0).float().to(self.device)
            tokens = tokens.to(self.device)
            labels = labels.to(self.device)

            batch_loss, logits = self.model(
                tokens, attention_mask=mask, labels=labels)
            loss_meter.update(batch_loss.detach().item())

            labels_np = labels.cpu().numpy()
            logits_np = logits.detach().cpu().numpy()
            acc_meter.update(get_accuracy(logits_np, labels_np),
                             labels_np.shape[0])

    template = 'Validating epoch {} | loss: {} - accuracy: {}'
    summary = template.format(self.cur_epoch,
                              round(loss_meter.val, 5),
                              round(acc_meter.val, 5))
    print_and_log(self.logger, summary)
    return acc_meter.val, loss_meter.val
def __init__(self, device, logger, data_name, input_length, max_epochs, lr,
             aug_mode, mode, batch_size, accumulation_steps,
             small_label=None, small_prop=None, balance_seed=None,
             undersample=False, pct_usage=None, geo=0.5, split_num=0,
             verbose=False):
    """Store the run settings and build the dataset manager for BERT.

    Exactly one of the following must hold (checked with ``assert``):
    ``mode == 'test-aug'``, ``mode == 'save'``, ``pct_usage`` is given, or
    ``small_label`` is given. Undersampling cannot be combined with an
    augmentation mode.

    Args:
        device: Device spec used when CUDA is available; otherwise the
            agent falls back to ``'cpu'``.
        logger: Logger handed to ``print_and_log``.
        data_name: One of ``'sst'``, ``'subj'`` or ``'sfu'``.
        input_length, aug_mode, pct_usage, geo, batch_size: Forwarded
            positionally to the dataset manager (after the literal
            ``'bert'``).
        small_label, small_prop, balance_seed, undersample, split_num:
            Forwarded to the dataset manager as keyword arguments.
        max_epochs, lr, mode, accumulation_steps, verbose: Stored on the
            instance for later use.

    Raises:
        AssertionError: If the argument combination is invalid.
        ValueError: If ``data_name`` is not recognised.
    """
    assert not undersample or aug_mode is None, \
        'Cant undersample and augment'
    exclusive_options = (
        mode == 'test-aug',
        mode == 'save',
        pct_usage is not None,
        small_label is not None,
    )
    assert sum(exclusive_options) == 1, \
        'Either saving, balancing, or trying on specific percentage'

    self.logger = logger
    self.data_name = data_name
    self.input_length = input_length
    self.max_epochs = max_epochs
    self.lr = lr
    self.aug_mode = aug_mode
    self.mode = mode
    self.batch_size = batch_size
    self.accumulation_steps = accumulation_steps
    self.small_label = small_label
    self.small_prop = small_prop
    self.balance_seed = balance_seed
    self.undersample = undersample
    self.pct_usage = pct_usage
    self.geo = geo
    self.split_num = split_num
    self.verbose = verbose

    manager_args = ['bert', self.input_length, self.aug_mode,
                    self.pct_usage, self.geo, self.batch_size]
    manager_kwargs = dict(
        small_label=self.small_label,
        small_prop=self.small_prop,
        balance_seed=self.balance_seed,
        undersample=undersample,
        split_num=self.split_num,
    )

    # Pick the manager class first; every supported dataset is binary.
    if data_name == 'sst':
        manager_cls = SSTDatasetManager
    elif data_name == 'subj':
        manager_cls = SubjDatasetManager
    elif data_name == 'sfu':
        manager_cls = SFUDatasetManager
    else:
        raise ValueError('Data name not recognized.')
    self.num_labels = 2
    self.mngr = manager_cls(*manager_args, **manager_kwargs)

    chosen_device = device if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(chosen_device)

    template = ('Model is Bert, dataset is {}, undersample is {},'
                ' aug mode is {}, geo is {}, pct_usage is {},'
                ' small_label is {},'
                ' small_prop is {}, balance_seed is {}, lr is {},'
                ' max_epochs is {}, split_num is {}')
    summary = template.format(
        data_name, self.undersample, self.aug_mode, self.geo,
        self.pct_usage, self.small_label, self.small_prop,
        self.balance_seed, self.lr, self.max_epochs, self.split_num)
    print_and_log(self.logger, summary)
def __init__(self, device, logger, data_name, input_length, max_epochs, lr,
             aug_mode, mode, batch_size, small_label=None, small_prop=None,
             balance_seed=None, undersample=False, pct_usage=None, geo=0.5,
             split_num=0, verbose=False):
    """Store the run settings and build the dataset manager for the RNN.

    Exactly one of the following must hold (checked with ``assert``):
    ``mode == 'save'``, ``pct_usage`` is given, or ``small_label`` is
    given. Undersampling cannot be combined with an augmentation mode.

    A spaCy ``en_core_web_md`` pipeline is loaded with the parser, tagger
    and NER disabled, and the vector for vocabulary key ``0`` is set to all
    zeros before the pipeline is handed to the dataset manager.

    Args:
        device: Device spec used when CUDA is available; otherwise the
            agent falls back to ``'cpu'``.
        logger: Logger handed to ``print_and_log``.
        data_name: One of ``'sst'``, ``'subj'`` or ``'sfu'``.
        input_length, aug_mode, pct_usage, geo, batch_size: Forwarded
            positionally to the dataset manager (after the literal
            ``'rnn'``).
        small_label, small_prop, balance_seed, undersample, split_num:
            Forwarded to the dataset manager as keyword arguments.
        max_epochs, lr, mode, verbose: Stored on the instance for later
            use.

    Raises:
        AssertionError: If the argument combination is invalid.
        ValueError: If ``data_name`` is not recognised.
    """
    assert not undersample or aug_mode is None, \
        'Cant undersample and augment'
    exclusive_options = (
        mode == 'save',
        pct_usage is not None,
        small_label is not None,
    )
    assert sum(exclusive_options) == 1, \
        'Either saving, balancing, or trying on specific percentage'

    self.logger = logger
    self.data_name = data_name
    self.input_length = input_length
    self.max_epochs = max_epochs
    self.lr = lr
    self.aug_mode = aug_mode
    self.mode = mode
    self.batch_size = batch_size
    self.small_label = small_label
    self.small_prop = small_prop
    self.balance_seed = balance_seed
    self.undersample = undersample
    self.pct_usage = pct_usage
    self.geo = geo
    self.split_num = split_num
    self.verbose = verbose

    self.loss = CrossEntropyLoss()

    # Only word vectors are needed, so the heavier components are disabled.
    pipeline = spacy.load('en_core_web_md',
                          disable=['parser', 'tagger', 'ner'])
    vector_width = pipeline.vocab.vectors.shape[1]
    pipeline.vocab.set_vector(0, vector=np.zeros(vector_width))
    self.nlp = pipeline

    manager_args = ['rnn', self.input_length, self.aug_mode,
                    self.pct_usage, self.geo, self.batch_size]
    manager_kwargs = dict(
        nlp=self.nlp,
        small_label=self.small_label,
        small_prop=self.small_prop,
        balance_seed=self.balance_seed,
        undersample=self.undersample,
        split_num=self.split_num,
    )

    # Pick the manager class first; every supported dataset is binary.
    if data_name == 'sst':
        manager_cls = SSTDatasetManager
    elif data_name == 'subj':
        manager_cls = SubjDatasetManager
    elif data_name == 'sfu':
        manager_cls = SFUDatasetManager
    else:
        raise ValueError('Data name not recognized.')
    self.num_labels = 2
    self.mngr = manager_cls(*manager_args, **manager_kwargs)

    chosen_device = device if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(chosen_device)

    template = ('Model is RNN, dataset is {}, undersample is {},'
                ' aug mode is {}, geo is {}, pct_usage is {},'
                ' small_label is {},'
                ' small_prop is {}, balance_seed is {}, lr is {},'
                ' max_epochs is {}, split_num is {}')
    summary = template.format(
        data_name, self.undersample, self.aug_mode, self.geo,
        self.pct_usage, self.small_label, self.small_prop,
        self.balance_seed, self.lr, self.max_epochs, self.split_num)
    print_and_log(self.logger, summary)