def multiply_learning_rate(self, factor):
    """Scale the learning rate of every optimizer parameter group.

    When ``Config.use_adaptive_learning_rate_adjustment`` is set, each
    group's current rate is multiplied by ``factor``. Otherwise every group
    is assigned ``self.lr * factor``. Whenever the adaptive flag is set, or
    the target rate differs from a group's current rate, the change is
    printed and appended to the results file.

    :param factor: multiplier applied to the learning rate
    """
    target_lr = self.lr * factor

    for group in self.optimizer.param_groups:
        adaptive = Config.use_adaptive_learning_rate_adjustment
        current_lr = group['lr']

        # Adaptive mode always reports; fixed mode reports only real changes.
        if target_lr != current_lr or adaptive:
            reported_lr = current_lr * factor if adaptive else target_lr
            message = f"updating lr from {current_lr!r} to {reported_lr!r}"
            print(message)
            with open(DataManager.get_results_file(), 'a+') as results_file:
                results_file.write(message + '\n')

        if adaptive:
            # Compound on the group's own rate rather than on self.lr.
            group['lr'] *= factor
        else:
            group['lr'] = target_lr
def train_epoch(model, train_loader, epoch, test_loader, device, augmentors=None, print_accuracy=False, drop_adaptive_learning_rate=False): """ Run a single train epoch :param model: the network of type torch.nn.Module :param train_loader: the training dataset :param epoch: the current epoch :param test_loader: The test dataset loader :param print_accuracy: True if should test when printing batch info """ model.train() loss = 0 batch_idx = 0 loops = 1 if not augmentors else 1 + len(augmentors) loops = 1 if Config.batch_by_batch else loops if Config.drop_learning_rate: if not Config.use_adaptive_learning_rate_adjustment: learning_rate_coefficient = 1 / pow(Config.drop_factor, math.floor(epoch / Config.drop_period)) model.multiply_learning_rate(learning_rate_coefficient) elif drop_adaptive_learning_rate: model.multiply_learning_rate(1 / Config.drop_factor) for i in range(loops): if i == 0 and not Config.train_on_origonal_data and not Config.batch_by_batch: # skip origInal data in epoch splicing continue for batch_idx, (inputs, targets) in enumerate(train_loader): # batch loop model.optimizer.zero_grad() has_augs = augmentors is not None and len(augmentors) > 0 train_on_aug = has_augs and i >= 1 train_on_aug = train_on_aug or (Config.batch_by_batch and has_augs) if train_on_aug: if not Config.batch_by_batch: augmentor = augmentors[i - 1] if augmentor is None: continue loss += train_batch(model, inputs, targets, device, augmentor=augmentor) else: for augmentor in augmentors: loss += train_batch(model, inputs, targets, device, augmentor=augmentor) model.optimizer.zero_grad() train_on_original = i == 0 train_on_original = train_on_original or Config.batch_by_batch if train_on_original: loss += train_batch(model, inputs, targets, device) if print_epoch_every != -1 and epoch % print_epoch_every == 0: if print_accuracy: test_acc = test(model, test_loader, device, print_acc=False) print("epoch", epoch, "average loss:", loss / batch_idx, "accuracy:", test_acc, "i = ", i) with 
open(DataManager.get_results_file(), 'a+') as f: f.write(repr(epoch) + ': ' + repr(test_acc) + ', loss: ' + repr(loss / batch_idx)) f.write('\n') model.train() return test_acc else: print("epoch", epoch, "average loss:", loss / batch_idx, "i=", i) end_time = time.time()