def fit(self, learning_rate: Tuple[float, float]):
    """Train the model under a one-cycle LR policy and report per-epoch losses.

    Args:
        learning_rate: Passed unchanged to ``OneCycleLR`` as ``max_lr``.

    Returns:
        A ``pandas.DataFrame`` with one row per epoch and the columns
        ``'Train error'`` and ``'Validation error'`` (mean batch loss).

    Side effects:
        Resets ``self.train_val_error``; besides the per-epoch losses it also
        collects the learning rate used for every training batch under the
        ``"lr"`` key.
    """
    # Local import: only needed for the no-grad validation context below.
    import torch

    # Capture learning errors
    self.train_val_error = {"train": [], "validation": [], "lr": []}
    self._init_model(
        model=self.model_, optimizer=self.optimizer_, criterion=self.criterion_
    )

    # Setup one cycle policy
    scheduler = OneCycleLR(
        optimizer=self.optimizer,
        max_lr=learning_rate,
        steps_per_epoch=len(self.train_loader),
        epochs=self.n_epochs,
        anneal_strategy="cos",
    )

    for _ in range(self.n_epochs):
        # Training set
        self.model.train()
        train_loss = 0.0
        n_train_batches = 0
        for samples in self.train_loader:
            # Forward pass, get loss
            loss = self._forward_pass(samples=samples)
            train_loss += loss.item()
            n_train_batches += 1

            # Zero gradients, perform a backward pass, and update the weights.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Record the LR that was in effect for this batch.  get_last_lr()
            # is the supported accessor outside of scheduler.step().
            self.train_val_error["lr"].append(scheduler.get_last_lr()[0])

            # One cycle scheduler must be called per batch
            # https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.OneCycleLR
            scheduler.step()

        # Mean over the number of batches seen.  (Dividing by the last
        # enumerate index undercounts by one and fails for a single batch.)
        train_err = (
            train_loss / n_train_batches if n_train_batches else float("nan")
        )
        self.train_val_error["train"].append(train_err)

        # Validation set: no parameter update happens here, so skip autograd
        # bookkeeping entirely.
        self.model.eval()
        validation_loss = 0.0
        n_valid_batches = 0
        with torch.no_grad():
            for samples in self.valid_loader:
                loss = self._forward_pass(samples=samples)
                validation_loss += loss.item()
                n_valid_batches += 1

        val_err = (
            validation_loss / n_valid_batches
            if n_valid_batches
            else float("nan")
        )
        self.train_val_error["validation"].append(val_err)

    return pd.DataFrame(data={
        'Train error': self.train_val_error['train'],
        'Validation error': self.train_val_error['validation']
    })
def start_train(self, epochs=10, device=device):
    """Run ``epochs`` rounds of training followed by evaluation.

    Builds an SGD optimizer over ``self.model`` and a ``OneCycleLR`` schedule
    sized for ``len(self.train_loader) * epochs`` steps, then alternates
    ``self.train_epoch`` and ``self.test_epoch``.

    Args:
        epochs: Number of train/test rounds to run.
        device: Accepted for interface compatibility; not used in this method.
    """
    optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
    # NOTE(review): the schedule is sized per batch, so train_epoch presumably
    # calls scheduler.step() once per batch -- confirm.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=0.1,
        steps_per_epoch=len(self.train_loader),
        epochs=epochs,
    )

    for epoch in range(epochs):
        # Print Learning Rate.  get_last_lr() is the supported way to read the
        # current LR outside of scheduler.step().
        print("EPOCH:", epoch + 1, 'LR:', scheduler.get_last_lr())
        self.train_epoch(optimizer, scheduler)
        self.test_epoch()
def train(ox: Oxentiel, env: gym.Env) -> None:
    """Train an actor-critic policy-gradient model with hyperparams from ``ox``.

    The loop steps ``env`` for ``ox.iterations`` timesteps, storing every
    transition in a ``RolloutStorage`` buffer.  Whenever an episode ends or the
    buffer is full, advantages and rewards-to-go are computed for the current
    episode segment; whenever the buffer is full, one optimizer step is taken
    for the policy and one for the value function, each followed by a step of
    its ``OneCycleLR`` scheduler.  The model is periodically saved to
    ``ox.save_path``.

    Args:
        ox: Hyperparameter container (learning rate, batch size, iteration
            count, save interval/path, ...).
        env: Environment to interact with.  ``env.step`` must return an
            ``info`` dict with ``"co2"`` and ``"oob"`` entries.
    """
    # Set shapes and dimensions for use in type hints.
    dims.RESOLUTION = ox.resolution
    dims.BATCH = ox.batch_size
    dims.ACTS = env.action_space.n
    shapes.OB = env.observation_space.shape

    # Make the policy object.
    ac = ActorCritic(shapes.OB[0], ox.hidden_dim, dims.ACTS)

    # Make optimizers.
    policy_optimizer = Adam(ac.pi.parameters(), lr=ox.lr)
    value_optimizer = Adam(ac.v.parameters(), lr=ox.lr)
    policy_scheduler = OneCycleLR(
        policy_optimizer, ox.lr, ox.lr_cycle_steps, pct_start=ox.pct_start
    )
    value_scheduler = OneCycleLR(
        value_optimizer, ox.lr, ox.lr_cycle_steps, pct_start=ox.pct_start
    )

    # Create a buffer object to store trajectories.
    rollouts = RolloutStorage(ox.batch_size, shapes.OB)

    # Get the initial observation.
    ob: Array[float, shapes.OB]
    ob = env.reset()

    # Per-episode environment statistics taken from ``info``.
    oobs = []
    co2s = []
    mean_co2 = 0
    num_oobs = 0

    t_start = time.time()

    for i in range(ox.iterations):

        # Sample an action from the policy and estimate the value of current state.
        act: Array[int, ()]
        val: Array[float, ()]
        act, val = get_action(ac, ob)

        # Step the environment to get new observation, reward, done status, and info.
        next_ob: Array[float, shapes.OB]
        rew: int
        done: bool
        next_ob, rew, done, info = env.step(int(act))

        # Get co2 lbs.
        co2s.append(info["co2"])
        oobs.append(info["oob"])

        # Add data for a timestep to the buffer.
        rollouts.add(ob, act, val, rew)

        # Don't forget to update the observation.
        ob = next_ob

        # If we reached a terminal state, or we completed a batch.
        if done or rollouts.batch_len == ox.batch_size:

            # Step 1: Compute advantages and critic targets.

            # Get episode length.
            ep_len = rollouts.ep_len
            dims.EP_LEN = ep_len

            # Retrieve values and rewards for the current episode.
            vals: Array[float, ep_len]
            rews: Array[float, ep_len]
            vals, rews = rollouts.get_episode_values_and_rewards()
            mean_rew = np.mean(rews)

            # The last value should be zero if this is the end of an episode.
            last_val: float = 0.0 if done else vals[-1]

            # Compute advantages and rewards-to-go.
            advs: Array[float, ep_len] = get_advantages(ox, rews, vals, last_val)
            rtgs: Array[float, ep_len] = get_rewards_to_go(ox, rews)

            # Record the episode length.
            if done:
                rollouts.lens.append(len(advs))
                rollouts.rets.append(np.sum(rews))

                # Reset the environment.
                ob = env.reset()

                # Despite its name this is the episode total (it is printed
                # below as "Total co2").
                mean_co2 = sum(co2s)
                num_oobs = sum([int(oob) for oob in oobs])
                co2s = []
                oobs = []

            # Step 2: Reset vals and rews in buffer and record computed quantities.
            rollouts.vals[:] = 0
            rollouts.rews[:] = 0

            # Record advantages and rewards-to-go.
            j = rollouts.ep_start
            assert j + ep_len <= ox.batch_size
            rollouts.advs[j:j + ep_len] = advs
            rollouts.rtgs[j:j + ep_len] = rtgs
            rollouts.ep_start = j + ep_len
            rollouts.ep_len = 0

        # If we completed a batch.
        if rollouts.batch_len == ox.batch_size:

            # Get batch data from the buffer.
            obs: Tensor[float, (ox.batch_size, *shapes.OB)]
            acts: Tensor[int, (ox.batch_size)]
            obs, acts, advs, rtgs = rollouts.get_batch()

            # Run a backward pass on the policy (actor).
            policy_optimizer.zero_grad()
            policy_loss = get_policy_loss(ac.pi, obs, acts, advs)
            policy_loss.backward()
            policy_optimizer.step()
            policy_scheduler.step()

            # Run a backward pass on the value function (critic).
            value_optimizer.zero_grad()
            value_loss = get_value_loss(ac.v, obs, rtgs)
            value_loss.backward()
            value_optimizer.step()
            value_scheduler.step()

            # Reset pointers.
            rollouts.batch_len = 0
            rollouts.ep_start = 0

            # Print statistics.  get_last_lr() is the supported way to read
            # the current LR outside of scheduler.step().
            lr = policy_scheduler.get_last_lr()
            print(f"Iteration: {i + 1} | ", end="")
            print(f"Time: {time.time() - t_start:.5f} | ", end="")
            print(f"Total co2: {mean_co2:.5f} | ", end="")
            print(f"Num OOBs: {num_oobs:.5f} | ", end="")
            print(f"LR: {lr} | ", end="")
            print(f"Mean reward for current batch: {mean_rew:.5f}")
            t_start = time.time()
            rollouts.rets = []
            rollouts.lens = []

        if i > 0 and i % ox.save_interval == 0:
            with open(ox.save_path, "wb") as model_file:
                torch.save(ac, model_file)
                print("=== saved model ===")
output_cat = torch.empty(0, dtype=torch.long) for batch in tqdm(testloader, desc=f"Test {epoch}", leave=False): img_cpu, label_cpu = batch img = img_cpu.to(device) label_1hot = [] output = model(img) # collect data for the f1 score lables_cat = torch.cat((lables_cat, label_cpu)) output_cat = torch.cat((output_cat, output.argmax(axis=1).cpu())) # calculate the f1 score test_f1 = f1_score(lables_cat, output_cat, average='macro') # write parameters to log log.add_scalar("Train F1", train_f1, global_step=epoch) log.add_scalar("Test F1", test_f1, global_step=epoch) log.add_scalar("Loss", np.mean(mean_loss), global_step=epoch) log.add_scalar("LR", scheduler.get_lr()[0], global_step=epoch) print(f"{epoch}: " f"train f1 {train_f1*100:.2f}%, " f"test f1 {test_f1*100:.2f}% " f"loss {np.mean(mean_loss):.3f}", flush=True) # close log log.close() # save model model.save("MNIST_SVM_3.pt")
# Plot the learning-rate curve that OneCycleLR produces over a full run.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

bs = 64
base_lr = 0.03 * bs / 64
optim = torch.optim.SGD(model.parameters(), lr=base_lr)

imgs_per_epoch = 64 * 1024
iters_per_epoch = 1024
num_epochs = 30
max_iter = int(num_epochs * iters_per_epoch)
print(max_iter)

lr_scheduler = OneCycleLR(optim, base_lr, total_steps=max_iter)

# Record the LR in effect before each scheduler step.  get_last_lr() is the
# supported accessor outside of scheduler.step().
lrs = []
for _ in range(max_iter):
    lr = lr_scheduler.get_last_lr()[0]
    lrs.append(lr)
    lr_scheduler.step()

lrs = np.array(lrs)
n_lrs = len(lrs)
plt.plot(np.arange(n_lrs), lrs)
plt.title('3')
plt.grid()
plt.show()
def train(args, writer):
    """Train a ``TextRNN`` classifier and log progress to ``writer``.

    Builds the training dataset and vocabularies, initialises the model's
    embedding layer from the pretrained vectors, then runs
    ``args.num_train_epochs`` epochs with Adam and a per-step ``OneCycleLR``
    schedule.  Evaluation results are logged every ``args.logging_steps``
    steps and the model is saved every ``args.save_steps`` steps (each is
    disabled when the value is not positive).

    Args:
        args: Parsed run configuration (paths, sizes, learning rate, device,
            logging/saving intervals, ...).
        writer: Summary writer exposing ``add_scalar`` and ``close``.  It is
            closed before this function returns.
    """
    # Build train dataset
    fields, train_dataset = build_and_cache_dataset(args, mode='train')

    # Build vocab
    ID, CATEGORY, NEWS = fields
    vectors = Vectors(name=args.embed_path, cache=args.cache_dir)
    # NOTE: use train_dataset to build vocab!
    NEWS.build_vocab(
        train_dataset,
        max_size=args.vocab_size,
        vectors=vectors,
        unk_init=torch.nn.init.xavier_normal_,
    )
    CATEGORY.build_vocab(train_dataset)

    model = TextRNN(
        vocab_size=len(NEWS.vocab),
        output_dim=args.num_labels,
        pad_idx=NEWS.vocab.stoi[NEWS.pad_token],
        dropout=args.dropout,
    )

    # Init embeddings for model.  ``from_pretrained`` is a classmethod that
    # returns a NEW embedding module, so calling it on the instance and
    # discarding the result leaves the weights untouched.  Copy the vectors
    # into the existing layer instead.
    # NOTE(review): assumes model.embedding.weight has the same shape as
    # NEWS.vocab.vectors -- confirm against TextRNN's embedding dimension.
    model.embedding.weight.data.copy_(NEWS.vocab.vectors)

    bucket_iterator = BucketIterator(
        train_dataset,
        batch_size=args.train_batch_size,
        sort_within_batch=True,
        shuffle=True,
        sort_key=lambda x: len(x.news),
        device=args.device,
    )

    # optimizer, lr_scheduler, criterion
    model.to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=args.learning_rate,
                     eps=args.adam_epsilon)
    scheduler = OneCycleLR(optimizer,
                           max_lr=args.learning_rate * 10,
                           epochs=args.num_train_epochs,
                           steps_per_epoch=len(bucket_iterator))

    global_step = 0
    model.zero_grad()
    train_trange = trange(0, args.num_train_epochs, desc="Train epoch")
    for _ in train_trange:
        epoch_iterator = tqdm(bucket_iterator, desc='Training')
        for step, batch in enumerate(epoch_iterator):
            model.train()
            news, news_lengths = batch.news
            category = batch.category
            preds = model(news, news_lengths)
            loss = criterion(preds, category)
            loss.backward()

            # Logging.  get_last_lr() is the supported way to read the
            # current LR outside of scheduler.step().
            writer.add_scalar('Train/Loss', loss.item(), global_step)
            writer.add_scalar('Train/lr', scheduler.get_last_lr()[0],
                              global_step)

            # NOTE: Update model, optimizer should update before scheduler
            optimizer.step()
            scheduler.step()
            # Clear gradients so the next step does not accumulate on top of
            # this one.
            model.zero_grad()
            global_step += 1

            # NOTE: Evaluate
            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                results = evaluate(args, model, CATEGORY.vocab, NEWS.vocab)
                for key, value in results.items():
                    writer.add_scalar("Eval/{}".format(key), value, global_step)

            # NOTE: save model
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                save_model(args, model, optimizer, scheduler, global_step)

    writer.close()
def train_main():
    """Entry point: train a segmentation model and track the best mIoU.

    Parses the command line, creates a timestamped checkpoint directory,
    prepares data loaders, model, losses, optimizer and a once-per-epoch
    ``OneCycleLR`` schedule, then trains and validates for ``args.epochs``
    epochs.  Per-epoch metrics go to ``logs.csv``; checkpoints are written for
    the latest epoch and (from epoch 10 on) for every new best mIoU.  A
    ``finished.txt`` summary is written at the end.
    """
    args = parse_args()

    # directory for storing weights and other training related files
    training_starttime = datetime.now().strftime("%d_%m_%Y-%H_%M_%S-%f")
    ckpt_dir = os.path.join(args.results_dir, args.dataset,
                            f'checkpoints_{training_starttime}')
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(os.path.join(ckpt_dir, 'confusion_matrices'), exist_ok=True)

    with open(os.path.join(ckpt_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    with open(os.path.join(ckpt_dir, 'argsv.txt'), 'w') as f:
        f.write(' '.join(sys.argv))
        f.write('\n')

    # when using multi scale supervision the label needs to be downsampled.
    label_downsampling_rates = [8, 16, 32]

    # data preparation ---------------------------------------------------------
    data_loaders = prepare_data(args, ckpt_dir)

    if args.valid_full_res:
        train_loader, valid_loader, valid_loader_full_res = data_loaders
    else:
        train_loader, valid_loader = data_loaders
        valid_loader_full_res = None

    cameras = train_loader.dataset.cameras
    n_classes_without_void = train_loader.dataset.n_classes_without_void
    if args.class_weighting != 'None':
        class_weighting = train_loader.dataset.compute_class_weights(
            weight_mode=args.class_weighting,
            c=args.c_for_logarithmic_weighting)
    else:
        class_weighting = np.ones(n_classes_without_void)

    # model building -----------------------------------------------------------
    model, device = build_model(args, n_classes=n_classes_without_void)

    if args.freeze > 0:
        print('Freeze everything but the output layer(s).')
        for name, param in model.named_parameters():
            if 'out' not in name:
                param.requires_grad = False

    # loss, optimizer, learning rate scheduler, csvlogger ----------

    # loss functions (only loss_function_train is really needed.
    # The other loss functions are just there to compare valid loss to
    # train loss)
    loss_function_train = \
        utils.CrossEntropyLoss2d(weight=class_weighting, device=device)

    pixel_sum_valid_data = valid_loader.dataset.compute_class_weights(
        weight_mode='linear')
    pixel_sum_valid_data_weighted = \
        np.sum(pixel_sum_valid_data * class_weighting)
    loss_function_valid = utils.CrossEntropyLoss2dForValidData(
        weight=class_weighting,
        weighted_pixel_sum=pixel_sum_valid_data_weighted,
        device=device)
    loss_function_valid_unweighted = \
        utils.CrossEntropyLoss2dForValidDataUnweighted(device=device)

    optimizer = get_optimizer(args, model)

    # in this script lr_scheduler.step() is only called once per epoch
    lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=[i['lr'] for i in optimizer.param_groups],
        total_steps=args.epochs,
        div_factor=25,
        pct_start=0.1,
        anneal_strategy='cos',
        final_div_factor=1e4)

    # load checkpoint if parameter last_ckpt is provided
    # NOTE(review): only model/optimizer state is passed to load_ckpt; the
    # scheduler created above starts from step 0 -- confirm this is intended
    # when resuming.
    if args.last_ckpt:
        ckpt_path = os.path.join(ckpt_dir, args.last_ckpt)
        epoch_last_ckpt, best_miou, best_miou_epoch = \
            load_ckpt(model, optimizer, ckpt_path, device)
        start_epoch = epoch_last_ckpt + 1
    else:
        start_epoch = 0
        best_miou = 0
        best_miou_epoch = 0

    valid_split = valid_loader.dataset.split

    # build the log keys for the csv log file and for the web logger
    log_keys = [f'mIoU_{valid_split}']
    if args.valid_full_res:
        log_keys.append(f'mIoU_{valid_split}_full-res')
        best_miou_full_res = 0
        # Initialised together with the score so that finished.txt can always
        # be written, even when no epoch improves on it (or none is run).
        best_miou_full_res_epoch = 0

    log_keys_for_csv = log_keys.copy()

    # mIoU for each camera
    for camera in cameras:
        log_keys_for_csv.append(f'mIoU_{valid_split}_{camera}')
        if args.valid_full_res:
            log_keys_for_csv.append(f'mIoU_{valid_split}_full-res_{camera}')

    log_keys_for_csv.append('epoch')
    # one lr column per optimizer param group; get_last_lr() is the supported
    # accessor outside of scheduler.step().
    log_keys_for_csv.extend(
        'lr_{}'.format(i) for i in range(len(lr_scheduler.get_last_lr())))
    log_keys_for_csv.extend(['loss_train_total', 'loss_train_full_size'])
    for rate in label_downsampling_rates:
        log_keys_for_csv.append('loss_train_down_{}'.format(rate))
    log_keys_for_csv.extend([
        'time_training', 'time_validation', 'time_confusion_matrix',
        'time_forward', 'time_post_processing', 'time_copy_to_gpu'
    ])

    valid_names = [valid_split]
    if args.valid_full_res:
        valid_names.append(valid_split + '_full-res')
    for valid_name in valid_names:
        # iou for every class
        for i in range(n_classes_without_void):
            log_keys_for_csv.append(f'IoU_{valid_name}_class_{i}')
        log_keys_for_csv.append(f'loss_{valid_name}')
        if loss_function_valid_unweighted is not None:
            log_keys_for_csv.append(f'loss_{valid_name}_unweighted')

    csvlogger = CSVLogger(log_keys_for_csv, os.path.join(ckpt_dir, 'logs.csv'),
                          append=True)

    # one confusion matrix per camera and one for whole valid data
    confusion_matrices = {}
    for camera in cameras:
        confusion_matrices[camera] = \
            ConfusionMatrixTensorflow(n_classes_without_void)
    confusion_matrices['all'] = \
        ConfusionMatrixTensorflow(n_classes_without_void)

    # start training -----------------------------------------------------------
    for epoch in range(int(start_epoch), args.epochs):
        # unfreeze
        if args.freeze == epoch and args.finetune is None:
            print('Unfreezing')
            for param in model.parameters():
                param.requires_grad = True

        logs = train_one_epoch(
            model, train_loader, device, optimizer, loss_function_train, epoch,
            lr_scheduler, args.modality,
            label_downsampling_rates, debug_mode=args.debug)

        # validation after every epoch -----------------------------------------
        miou, logs = validate(
            model, valid_loader, device, cameras,
            confusion_matrices, args.modality, loss_function_valid, logs,
            ckpt_dir, epoch, loss_function_valid_unweighted,
            debug_mode=args.debug
        )

        if args.valid_full_res:
            miou_full_res, logs = validate(
                model, valid_loader_full_res, device, cameras,
                confusion_matrices, args.modality, loss_function_valid, logs,
                ckpt_dir, epoch, loss_function_valid_unweighted,
                add_log_key='_full-res', debug_mode=args.debug
            )

        logs.pop('time', None)
        csvlogger.write_logs(logs)

        # save weights
        print(miou['all'])
        save_current_checkpoint = False
        if miou['all'] > best_miou:
            best_miou = miou['all']
            best_miou_epoch = epoch
            save_current_checkpoint = True

        if args.valid_full_res and miou_full_res['all'] > best_miou_full_res:
            best_miou_full_res = miou_full_res['all']
            best_miou_full_res_epoch = epoch
            save_current_checkpoint = True

        # don't save weights for the first 10 epochs as mIoU is likely getting
        # better anyway
        if epoch >= 10 and save_current_checkpoint:
            save_ckpt(ckpt_dir, model, optimizer, epoch)

        # save / overwrite latest weights (useful for resuming training)
        save_ckpt_every_epoch(ckpt_dir, model, optimizer, epoch, best_miou,
                              best_miou_epoch)

    # write a finish file with the best miou values to get a quick overview
    # of the training result
    with open(os.path.join(ckpt_dir, 'finished.txt'), 'w') as f:
        f.write('best miou: {}\n'.format(best_miou))
        f.write('best miou epoch: {}\n'.format(best_miou_epoch))
        if args.valid_full_res:
            f.write(f'best miou full res: {best_miou_full_res}\n')
            f.write(f'best miou full res epoch: {best_miou_full_res_epoch}\n')

    print("Training completed ")