class Learner(object):
    def __init__(self, model, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
                 lr=7e-4, alpha=0.99, epsilon=1e-5, number_updates=int(1e6),
                 lrschedule='linear', use_actor_critic=False, rew_loss_coef=0.0,
                 st_loss_coef=0.0, subtree_loss_coef=0.0, nsteps=5, nenvs=1,
                 tree_depth=0):
        self.max_grad_norm = max_grad_norm
        self.use_actor_critic = use_actor_critic
        self.use_reward_loss = model.predict_rewards and rew_loss_coef > 0
        self.rew_loss_coef = rew_loss_coef
        self.use_st_loss = st_loss_coef > 0 and tree_depth > 0
        self.st_loss_coef = st_loss_coef
        self.subtree_loss_coef = subtree_loss_coef
        self.use_subtree_loss = subtree_loss_coef > 0
        self.model = model
        self.nsteps = nsteps
        self.nenvs = nenvs
        self.batch_size = nsteps * nenvs
        self.num_actions = model.num_actions
        self.tree_depth = tree_depth

        if USE_CUDA:
            self.model = self.model.cuda()

        if not self.use_actor_critic:
            self.target_model = copy.deepcopy(self.model)
            if USE_CUDA:
                self.target_model = self.target_model.cuda()

        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=lr,
                                             alpha=alpha, eps=epsilon)

        if lrschedule == "linear":
            self.scheduler = LambdaLR(self.optimizer,
                                      lambda step: 1.0 - (step / number_updates))
        elif lrschedule == "constant":
            self.scheduler = LambdaLR(self.optimizer, lambda step: 1.0)
        else:
            raise ValueError("lrschedule should be 'linear' or 'constant'")

        self.step = self.model.step
        if self.use_actor_critic:
            self.value = self.model.value
            self.ent_coef = ent_coef
            self.vf_coef = vf_coef
        else:
            self.value = self.target_model.value

    def train(self, obs, next_obs, returns, rewards, masks, actions, values):
        """
        :param obs: [batch_size x height x width x channels] observations in NHWC
        :param next_obs: [batch_size x height x width x channels] one-step next states
        :param returns: [batch_size] n-step discounted returns with bootstrapped value
        :param rewards: [batch_size] 1-step rewards
        :param masks: [batch_size] boolean episode termination mask
        :param actions: [batch_size] actions taken
        :param values: [batch_size] predicted state values
        """
        # compute the sequences we need to get back reward predictions
        action_sequences, reward_sequences, sequence_mask = build_sequences(
            [torch.from_numpy(actions), torch.from_numpy(rewards)],
            self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
        action_sequences = cudify(action_sequences.long().squeeze(-1))
        reward_sequences = make_variable(reward_sequences.squeeze(-1))
        sequence_mask = make_variable(sequence_mask.squeeze(-1))

        Q, V, tree_result = self.model(obs)
        actions = make_variable(torch.from_numpy(actions).long(), requires_grad=False)
        returns = make_variable(torch.from_numpy(returns), requires_grad=False)

        policy_loss, value_loss, reward_loss, state_loss, subtree_loss_np, policy_entropy = 0, 0, 0, 0, 0, 0

        if self.use_actor_critic:
            values = make_variable(torch.from_numpy(values), requires_grad=False)
            advantages = returns - values
            probs = F.softmax(Q, dim=-1)
            log_probs = F.log_softmax(Q, dim=-1)
            log_probs_taken = log_probs.gather(1, actions.unsqueeze(1)).squeeze()
            pg_loss = -torch.mean(log_probs_taken * advantages.squeeze())
            vf_loss = F.mse_loss(V, returns)
            entropy = -torch.mean(torch.sum(probs * log_probs, 1))
            loss = pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy
            policy_loss = pg_loss.data.cpu().numpy()
            value_loss = vf_loss.data.cpu().numpy()
            policy_entropy = entropy.data.cpu().numpy()
        else:
            Q_taken = Q.gather(1, actions.unsqueeze(1)).squeeze()
            bellman_loss = F.mse_loss(Q_taken, returns)
            loss = bellman_loss
            value_loss = bellman_loss.data.cpu().numpy()

        if self.use_reward_loss:
            r_taken = get_paths(tree_result["rewards"], action_sequences,
                                self.batch_size, self.num_actions)
            rew_loss = F.mse_loss(torch.cat(r_taken, 1), reward_sequences, reduce=False)
            rew_loss = torch.sum(rew_loss * sequence_mask) / sequence_mask.sum()
            loss = loss + rew_loss * self.rew_loss_coef
            reward_loss = rew_loss.data.cpu().numpy()

        if self.use_st_loss:
            st_embeddings = tree_result["embeddings"][0]
            st_targets, st_mask = build_sequences([st_embeddings.data], self.nenvs,
                                                  self.nsteps, self.tree_depth,
                                                  return_mask=True, offset=1)
            st_targets = make_variable(st_targets.view(self.batch_size, -1))
            st_mask = make_variable(st_mask.view(self.batch_size, -1))
            st_taken = get_paths(tree_result["embeddings"][1:], action_sequences,
                                 self.batch_size, self.num_actions)
            st_taken_cat = torch.cat(st_taken, 1)
            st_loss = F.mse_loss(st_taken_cat, st_targets, reduce=False)
            st_loss = torch.sum(st_loss * st_mask) / st_mask.sum()
            state_loss = st_loss.data.cpu().numpy()
            loss = loss + st_loss * self.st_loss_coef

        if self.use_subtree_loss:
            subtree_taken = get_subtree(tree_result["values"], action_sequences,
                                        self.batch_size, self.num_actions)
            target_subtrees = tree_result["values"][0:-1]
            subtree_taken_clip = time_shift_tree(subtree_taken, self.nenvs, self.nsteps, -1)
            target_subtrees_clip = time_shift_tree(target_subtrees, self.nenvs, self.nsteps, 1)
            subtree_loss = [(s_taken - s_target).pow(2).mean()
                            for (s_taken, s_target) in zip(subtree_taken_clip,
                                                           target_subtrees_clip)]
            subtree_loss = sum(subtree_loss)
            subtree_loss_np = subtree_loss.data.cpu().numpy()
            loss = loss + subtree_loss * self.subtree_loss_coef

        self.scheduler.step()
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        return policy_loss, value_loss, reward_loss, state_loss, subtree_loss_np, policy_entropy, grad_norm
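A minimal standalone sketch of the linear-vs-constant schedule the Learner builds above, assuming only torch; the parameter and step counts are illustrative. It shows that LambdaLR multiplies the optimizer's base lr by the lambda's return value, so the 'linear' branch decays the lr from `lr` to 0 over `number_updates` scheduler steps:

# Hedged sketch: the LambdaLR pattern used above, demonstrated in isolation.
import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model parameters
number_updates = 4
optimizer = torch.optim.RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5)
scheduler = LambdaLR(optimizer, lambda step: 1.0 - (step / number_updates))

for _ in range(number_updates):
    optimizer.step()
    scheduler.step()  # lr becomes 7e-4 * (1 - step / number_updates)
    print(optimizer.param_groups[0]['lr'])  # 5.25e-04, 3.5e-04, 1.75e-04, 0.0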
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        ResizeImage(256),
        T.RandomCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5),
        T.RandomGrayscale(),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([ResizeImage(256), T.CenterCrop(224), T.ToTensor(), normalize])

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform,
                          val_transform, MultipleApply([train_transform, val_transform]))
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.workers)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler
    optimizer = Adam(classifier.get_parameters(), args.lr)
    lr_scheduler = LambdaLR(
        optimizer,
        lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    if args.pretrain is None:
        # first pretrain the classifier with source data
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone, num_classes,
                                         bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer,
                                         finetune=not args.scratch).to(device)
        pretrain_optimizer = Adam(pretrain_model.get_parameters(), args.pretrain_lr)
        pretrain_lr_scheduler = LambdaLR(
            pretrain_optimizer,
            lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

        # start pretraining
        for epoch in range(args.pretrain_epochs):
            # pretrain for one epoch
            utils.pretrain(train_source_iter, pretrain_model, pretrain_optimizer,
                           pretrain_lr_scheduler, epoch, args, device)
            # validate to show pretraining progress
            utils.validate(val_loader, pretrain_model, args, device)

        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    checkpoint = torch.load(args.pretrain, map_location='cpu')
    classifier.load_state_dict(checkpoint)
    teacher = EmaTeacher(classifier, alpha=args.alpha)
    consistent_loss = L2ConsistencyLoss().to(device)
    class_balance_loss = ClassBalanceLoss(num_classes).to(device)

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, teacher, consistent_loss,
              class_balance_loss, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, teacher: EmaTeacher, consistent_loss, class_balance_loss,
          optimizer: Adam, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    cons_losses = AverageMeter('Cons Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, cons_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    teacher.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        (x_t1, x_t2), labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t1 = x_t1.to(device)
        x_t2 = x_t2.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, _ = model(x_s)
        y_t, _ = model(x_t1)
        y_t_teacher, _ = teacher(x_t2)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # compute confidence mask from the teacher's predictions
        y_t = F.softmax(y_t, dim=1)
        y_t_teacher = F.softmax(y_t_teacher, dim=1)
        max_prob, _ = y_t_teacher.max(dim=1)
        mask = (max_prob > args.threshold).float()
        # consistency loss
        cons_loss = consistent_loss(y_t, y_t_teacher, mask)
        # class-balance loss
        balance_loss = class_balance_loss(y_t) * mask.mean()

        loss = cls_loss + args.trade_off_cons * cons_loss + args.trade_off_balance * balance_loss

        # compute gradient and do an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # update teacher
        teacher.update()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        cls_losses.update(cls_loss.item(), x_s.size(0))
        cons_losses.update(cons_loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
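The `teacher.update()` call above is the mean-teacher step. `EmaTeacher`'s internals are not shown in this snippet; the following is a minimal sketch assuming the standard exponential-moving-average formulation, theta_teacher <- alpha * theta_teacher + (1 - alpha) * theta_student:

# Hedged sketch of an EMA teacher; EmaTeacher's actual implementation may differ.
import copy
import torch

class EmaTeacherSketch:
    def __init__(self, student: torch.nn.Module, alpha: float):
        self.student = student
        self.alpha = alpha
        self.teacher = copy.deepcopy(student)  # teacher starts as a copy of the student

    def update(self):
        # theta_teacher <- alpha * theta_teacher + (1 - alpha) * theta_student
        with torch.no_grad():
            for t_p, s_p in zip(self.teacher.parameters(), self.student.parameters()):
                t_p.mul_(self.alpha).add_(s_p, alpha=1 - self.alpha)

    def __call__(self, x):
        return self.teacher(x)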
def test_npg(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        norm_obs=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh, device=args.device)
    actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
                      unbounded=True, device=args.device).to(args.device)
    net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh, device=args.device)
    critic = Critic(net_c, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performance,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch
        lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = NPGPolicy(actor, critic, optim, dist,
                       discount_factor=args.gamma,
                       gae_lambda=args.gae_lambda,
                       reward_normalization=args.rew_norm,
                       action_scaling=True,
                       action_bound_method=args.bound_action_method,
                       lr_scheduler=lr_scheduler,
                       action_space=env.action_space,
                       advantage_normalization=args.norm_adv,
                       optim_critic_iters=args.optim_critic_iters,
                       actor_step_size=args.actor_step_size)

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)

    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_npg'
    log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer, update_interval=100, train_interval=100)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(policy, train_collector, test_collector, args.epoch,
                                  args.step_per_epoch, args.repeat_per_collect,
                                  args.test_num, args.batch_size,
                                  step_per_collect=args.step_per_collect,
                                  save_fn=save_fn, logger=logger, test_in_train=False)
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
def train(model, state, path, annotations, val_path, val_annotations, resize, max_size,
          jitter, batch_size, iterations, val_iterations, mixed_precision, lr, warmup,
          milestones, gamma, is_master=True, world=1, use_dali=True, verbose=True,
          metrics_url=None, logdir=None):
    'Train the model on the given dataset'

    # Prepare model
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=0.0001, momentum=0.9)

    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level='O2' if mixed_precision else 'O0',
                                      keep_batchnorm_fp32=True,
                                      loss_scale=128.0,
                                      verbosity=is_master)

    if world > 1:
        model = DistributedDataParallel(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])
    scheduler = LambdaLR(optimizer, schedule)

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path, jitter, max_size, batch_size, stride,
        world, annotations, training=True)
    if verbose:
        print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else 'gpu' if world == 1 else 'gpus'))
        print('     batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(logdir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            scheduler.step(iteration)

            # Forward pass
            profiler.start('fw')
            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations, len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(profiler.means['fw'],
                                                                profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate
                    })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations or iteration % val_iterations == 0):
                infer(model, val_path, None, resize, max_size, batch_size,
                      annotations=val_annotations, mixed_precision=mixed_precision,
                      is_master=is_master, world=world, use_dali=use_dali,
                      is_validation=True, verbose=False)
                model.train()

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()
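The `schedule` closure above combines a linear warmup (from 10% to 100% of the base lr) with step decay at each milestone. A small standalone check of the multipliers it produces; the warmup, milestone, and gamma values here are illustrative, not from the original config:

# Hedged sketch: evaluate the warmup + milestone multiplier at a few iterations.
warmup, milestones, gamma = 1000, (20000, 60000), 0.1

def schedule(train_iter):
    if warmup and train_iter <= warmup:
        return 0.9 * train_iter / warmup + 0.1  # ramps 0.1 -> 1.0
    return gamma ** len([m for m in milestones if m <= train_iter])

for it in (0, 500, 1000, 20000, 60000):
    print(it, schedule(it))  # 0.1, 0.55, 1.0, 0.1, 0.01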
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
    )
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        for update in range(self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES
                )

            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(
                    rollouts,
                    current_episode_reward,
                    episode_rewards,
                    episode_counts,
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            window_episode_reward.append(episode_rewards.clone())
            window_episode_counts.append(episode_counts.clone())

            losses = [value_loss, action_loss]
            stats = zip(
                ["count", "reward"],
                [window_episode_counts, window_episode_reward],
            )
            deltas = {
                k: (
                    (v[-1] - v[0]).sum().item()
                    if len(v) > 1
                    else v[0].sum().item()
                )
                for k, v in stats
            }
            deltas["count"] = max(deltas["count"], 1.0)

            writer.add_scalar(
                "reward", deltas["reward"] / deltas["count"], count_steps
            )

            writer.add_scalars(
                "losses",
                {k: l for l, k in zip(losses, ["value", "policy"])},
                count_steps,
            )

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info(
                    "update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)
                    )
                )

                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time, count_steps)
                )

                window_rewards = (
                    window_episode_reward[-1] - window_episode_reward[0]
                ).sum()
                window_counts = (
                    window_episode_counts[-1] - window_episode_counts[0]
                ).sum()

                if window_counts > 0:
                    logger.info(
                        "Average window size {} reward: {:3f}".format(
                            len(window_episode_reward),
                            (window_rewards / window_counts).item(),
                        )
                    )
                else:
                    logger.info("No episodes finish in current window")

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)
                )
                count_checkpoints += 1

    self.envs.close()
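`linear_decay` is defined elsewhere in this codebase and is not shown in the snippet; a minimal reconstruction consistent with how it is used above (multiplier 1 at update 0, approaching 0 at NUM_UPDATES), offered here as an assumption rather than the project's actual helper:

# Hedged sketch: plausible form of the linear_decay helper used above.
def linear_decay(epoch: int, total_num_updates: int) -> float:
    """Multiplier that decays linearly from 1 to 0 over total_num_updates."""
    return 1 - (epoch / float(total_num_updates))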
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps,
                                    last_epoch=-1):
    def lr_lambda(current_step):
        learning_rate = max(0.0, 1. - (float(current_step) / float(num_training_steps)))
        learning_rate *= min(1.0, float(current_step) / float(num_warmup_steps))
        return learning_rate
    return LambdaLR(optimizer, lr_lambda, last_epoch)
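A usage sketch for the helper above; the optimizer and step counts are illustrative. Note that the two factors multiply, so during warmup the lr ramps toward (not exactly to) the base lr, peaking at 1 - num_warmup_steps/num_training_steps times the base lr, and then decays linearly to 0:

# Hedged usage sketch for get_linear_schedule_with_warmup above.
import torch

param = torch.nn.Parameter(torch.zeros(1))  # stand-in for model parameters
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100,
                                            num_training_steps=1000)
for step in range(1000):
    optimizer.step()
    scheduler.step()
# lr ramps up over the first 100 steps (to ~0.9 * 1e-3, since the decay term
# also applies during warmup), then decays linearly toward 0 at step 1000.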
def train(args, dataset, dev_dataset, model, do_train=True):
    model.train()
    dataloader = DataLoader(dataset, shuffle=args.shuffle, batch_size=args.batch_size,
                            num_workers=args.num_workers)
    args.total_steps = len(dataloader) * args.epoch // args.gradient_accumulation_steps
    lr_lambda = lambda epoch: 0.9 ** epoch
    if args.optim == "adam":
        optimizer = Adam(model.parameters(), lr=args.lr, betas=(args.adam_B1, args.adam_B2),
                         weight_decay=args.weight_decay, eps=args.adam_eps)
        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    else:
        optimizer = Adagrad(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    if args.fp16:
        model, optimizer = to_fp16(args, model, optimizer)
    model = to_parallel(args, model)

    steps = 0
    best_score = defaultdict(lambda: [0, 0, 0])
    best_epoch = None
    for epoch in range(1, args.epoch + 1):
        if not do_train:
            break
        outputs = []
        tr_loss = []
        if epoch >= args.lr_dec_epoch:
            scheduler.step()
        for batch in tqdm.tqdm(dataloader, desc=f"TRAIN {epoch}"):
            loss, logit, *_ = model(inputs=batch["inputs"].to(args.device),
                                    labels=batch["label"].to(args.device))
            outputs += list(zip(batch["pageid"], logit.cpu().tolist()))
            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            tr_loss.append(loss.item())
            steps += 1
            if steps % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                model.zero_grad()
        print(f"|LOSS|{sum(tr_loss)/len(tr_loss)}|LR|{scheduler.get_lr()}|")
        score, _ = dataset.evaluate(outputs)
        print_score(score["micro ave"])
        dev_score, _ = eval(args, dev_dataset, model)
        print_score(dev_score["micro ave"], desc="DEV")
        if dev_score["micro ave"][-1] > best_score["micro ave"][-1]:
            best_epoch = epoch
            best_score = dev_score
            best_model_param = model.state_dict()

    model.freeze = True
    if best_score is None:
        best_score, _ = eval(args, dev_dataset, model)
    else:
        model.load_state_dict(best_model_param)
    print_score(best_score["micro ave"], desc="BEST DEV")
    score, _ = eval(args, dataset, model)
    print_score(score["micro ave"], desc="BEST(DEV) TRAIN")
    return model, score, best_score, best_epoch
                        'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
        else:
            params += [{'params': [value], 'lr': lr,
                        'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]

if args.cuda:
    fasterRCNN.cuda()

optimizer = {
    "adam": lambda: optim.Adam(params, lr=lr, betas=(0.9, 0.999),
                               weight_decay=1e-8, eps=1e-08),
    "rmsprop": lambda: optim.RMSprop(params, lr=lr, momentum=cfg.TRAIN.MOMENTUM,
                                     eps=0.001, weight_decay=1e-8),
    "sgd": lambda: optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
}[args.optimizer]()

# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs * len(dataloader), eta_min=0, last_epoch=-1)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=10)
scheduler = LambdaLR(
    optimizer,
    lambda x: (((1 + np.cos(x * np.pi / args.max_epochs)) / 2) ** 1.0) * 0.9 + 0.1)

if args.resume:
    load_name = os.path.join(
        output_dir,
        'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
    print("loading checkpoint %s" % load_name)
    checkpoint = torch.load(load_name)
    args.session = checkpoint['session']
    args.start_epoch = checkpoint['epoch']
    fasterRCNN.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr = optimizer.param_groups[0]['lr']
    if 'pooling_mode' in checkpoint.keys():
        cfg.POOLING_MODE = checkpoint['pooling_mode']
    print("loaded checkpoint %s" % load_name)
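The lambda above is a cosine annealing curve rescaled to decay from 1.0 to a floor of 0.1 over `max_epochs` scheduler steps, rather than to 0 as plain CosineAnnealingLR would. A quick standalone check of the multiplier; `max_epochs` here is illustrative:

# Hedged sketch: the cosine-with-floor multiplier used above, checked in isolation.
import numpy as np

max_epochs = 10
lam = lambda x: (((1 + np.cos(x * np.pi / max_epochs)) / 2) ** 1.0) * 0.9 + 0.1
print(lam(0), lam(max_epochs / 2), lam(max_epochs))  # 1.0, 0.55, 0.1 (approx.)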
def train(model: torch.nn.Module, train_dl: DataLoader, optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR, validation_evaluator: ClassificationEvaluator,
          n_epochs: int, device: AnyStr, log_interval: int = 1, patience: int = 10,
          neg_class_weight: float = None, model_dir: str = "local", split: str = '') -> float:
    best_loss = float('inf')
    patience_counter = 0
    best_f1 = 0.0
    weights_found = False

    loss_fn = torch.nn.CrossEntropyLoss(
        weight=torch.tensor([neg_class_weight, 1.]).to(device))

    # Main loop
    for ep in range(n_epochs):
        # Training loop
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]

            (logits,) = model(input_ids, attention_mask=masks)
            loss = loss_fn(logits.view(-1, 2), labels.view(-1))

            loss.backward()
            optimizer.step()
            scheduler.step()
            gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)

        # Saving the best model and early stopping
        if F1 > best_f1:
            weights_found = True
            best_model = model.state_dict()
            # best_loss = val_loss
            best_f1 = F1
            torch.save(model.state_dict(), f'{model_dir}/model_{split}.pth')
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

    if not weights_found:
        print("No good weights found, saving weights from last epoch")
        # Save one just in case
        torch.save(model.state_dict(), f'{model_dir}/model_{split}.pth')

    gc.collect()
    return best_f1
def fit(self, X, Y):
    start_time = datetime.now()
    self.log("==============================================================")
    self.log("hidden_size = " + str(self.hidden_size))
    self.log("num_layers = " + str(self.num_layers))
    self.log("batch_size = " + str(self.batch_size))
    self.log("num_epochs = " + str(self.num_epochs))
    self.log("init_lr = " + str(self.init_lr))
    self.log("l2_regu_weight_decay = " + str(self.l2_regu_weight_decay))
    self.log("lr_schedule_step_size = " + str(self.lr_schedule_step_size))
    self.log("lr_schedule_gamma = " + str(self.lr_schedule_gamma))
    self.log("clip_grad_norm_type = " + str(self.clip_grad_norm_type))
    self.log("use_class_weights = " + str(self.use_class_weights))
    self.log("reverse_input_seq = " + str(self.reverse_input_seq))
    self.log("dropout = " + str(self.dropout))
    self.log("max_grad_norm = " + str(self.max_grad_norm))
    self.log("bidirectional = " + str(self.bidirectional))
    self.log("use_all_out = " + str(self.use_all_out))
    self.log("use_gru = " + str(self.use_gru))
    self.log("--------------------------------------------------------------")

    # Parameters
    # X dimension is (batch_size * sequence_length * feature_size)
    sequence_length = X.shape[1]
    feature_size = X.shape[2]
    num_classes = len(np.unique(Y))

    # Model
    model = RNN(feature_size, self.hidden_size, self.num_layers, num_classes,
                sequence_length, dropout=self.dropout, use_cuda=self.use_cuda,
                bidirectional=self.bidirectional, use_all_out=self.use_all_out,
                use_gru=self.use_gru)
    if self.use_cuda:
        model.cuda()

    # Loss function
    criterion = nn.CrossEntropyLoss()
    # Compute the weight of each class (because the dataset is imbalanced)
    if self.use_class_weights:
        class_weights = float(X.shape[0]) / (num_classes * np.bincount(Y))
        class_weights = torch.FloatTensor(class_weights)
        if self.use_cuda:
            class_weights = class_weights.cuda()
        criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=self.init_lr,
                                 weight_decay=self.l2_regu_weight_decay)

    # Learning rate scheduler
    rule = lambda epoch: self.lr_schedule_gamma ** (epoch // self.lr_schedule_step_size)
    scheduler = LambdaLR(optimizer, lr_lambda=[rule])

    # Reverse the input time sequence
    if self.reverse_input_seq:
        X = np.flip(X, axis=1)

    # Save original training data
    self.train = {"X": deepcopy(X), "Y": deepcopy(Y)}

    # Break data into batches
    num_of_left_overs = self.batch_size - (X.shape[0] % self.batch_size)
    X = np.append(X, X[0:num_of_left_overs, :, :], 0)
    Y = np.append(Y, Y[0:num_of_left_overs])
    num_of_batches = X.shape[0] // self.batch_size
    X = np.split(X, num_of_batches, 0)
    Y = np.split(Y, num_of_batches, 0)

    # Train the model
    for epoch in range(1, self.num_epochs + 1):
        X, Y = shuffle(X, Y)  # shuffle batches
        loss_all = []  # for saving the loss in each step
        scheduler.step()  # adjust learning rate
        # Loop through all batches
        for x, y in zip(X, Y):
            x, y = torch.FloatTensor(x), torch.LongTensor(y)
            if self.use_cuda:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)
            optimizer.zero_grad()  # reset gradient
            outputs = model(x)  # forward propagation
            loss = criterion(outputs, y)  # compute loss
            loss.backward()  # backward propagation
            clip_grad_norm(model.parameters(), self.max_grad_norm,
                           norm_type=self.clip_grad_norm_type)  # clip gradient
            optimizer.step()  # optimize
            loss_all.append(loss.data[0])  # save loss for each step
        self.model = model  # save the model

        # Print the result for the entire epoch
        T_tr, P_tr = self.train["Y"], self.predict(self.train["X"])
        m_train = computeMetric(T_tr, P_tr, False, flatten=True, simple=True,
                                round_to_decimal=2)
        if self.test is not None:
            T_te, P_te = self.test["Y"], self.predict(self.test["X"])
            m_test = computeMetric(T_te, P_te, False, flatten=True, simple=True,
                                   round_to_decimal=2)
        lr_now = optimizer.state_dict()["param_groups"][0]["lr"]
        avg_loss = np.mean(loss_all)
        cm_names = " ".join(m_train["cm"][0])
        cm_train = " ".join(map(lambda x: '%5d' % (x), m_train["cm"][1]))
        if self.test is not None:
            cm_test = " ".join(map(lambda x: '%4d' % (x), m_test["cm"][1]))
            self.log('[%2d/%d], LR: %.8f, Loss: %.8f, [%s], [%s], [%s]'
                     % (epoch, self.num_epochs, lr_now, avg_loss,
                        cm_names, cm_train, cm_test))
        else:
            self.log('[%2d/%d], LR: %.9f, Loss: %.9f, [%s], [%s]'
                     % (epoch, self.num_epochs, lr_now, avg_loss, cm_names, cm_train))

    # Log the final result
    self.log("--------------------------------------------------------------")
    m_train = computeMetric(T_tr, P_tr, False)
    for m in m_train:
        self.log("Metric: " + m)
        self.log(m_train[m])
    if self.test is not None:
        self.log("--------------------------------------------------------------")
        m_test = computeMetric(T_te, P_te, False)
        for m in m_test:
            self.log("Metric: " + m)
            self.log(m_test[m])
    self.log("--------------------------------------------------------------")
    self.log("From " + str(start_time) + " to " + str(datetime.now()))
    self.log("==============================================================")
    return self
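The rule above, `gamma ** (epoch // step_size)`, reproduces what torch's built-in StepLR computes; a minimal equivalence sketch with illustrative values:

# Hedged sketch: the LambdaLR rule above is equivalent to StepLR with the
# same step_size and gamma. Optimizer and values here are illustrative.
import torch
from torch.optim.lr_scheduler import LambdaLR, StepLR

p1 = [torch.nn.Parameter(torch.zeros(1))]
p2 = [torch.nn.Parameter(torch.zeros(1))]
opt1 = torch.optim.Adam(p1, lr=0.001)
opt2 = torch.optim.Adam(p2, lr=0.001)
step_size, gamma = 10, 0.5
sched1 = LambdaLR(opt1, lr_lambda=lambda epoch: gamma ** (epoch // step_size))
sched2 = StepLR(opt2, step_size=step_size, gamma=gamma)
for _ in range(25):
    opt1.step()
    opt2.step()
    sched1.step()
    sched2.step()
assert opt1.param_groups[0]['lr'] == opt2.param_groups[0]['lr']  # both 0.00025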
class Learner(RpcCallee): """Agent that runs inference and learning in parallel. This learning agent implements the reinforcement learning algorithm IMPALA following the SEED RL implementation by Google Brain. During initiation: * Spawns :py:attr:`num_actors` instances of :py:class:`~.agents.Actor`. * Invokes their :py:meth:`~.RpcCaller.loop()` methods. * Creates a :py:class:`~.Recorder`. * Creates a :py:class:`~.TrajectoryStore`. * Starts a continous inference process to answer pending RPCs. * Starts a continous prefetching process to prepare batches of complete trajectories for learning. During runtime: * Runs evaluates observations received from :py:class:`~.agents.Actor` and returns actions. * Stores incomplete trajectories in :py:class:`~.TrajectoryStore`. * Trains a global model from trajectories received from a data prefetching thread. Parameters ---------- rank : `int` Rank given by the RPC group on initiation (as in :py:func:`torch.distributed.rpc.init_rpc`). num_actors : `int` Number of total :py:class:`~.agents.Actor` objects to spawn. env_spawner : :py:class:`~.EnvSpawner` Object that spawns an environment on invoking it's :py:meth:`~.EnvSpawner.spawn()` method. model : :py:class:`torch.nn.Module` A torch model that processes frames as returned by an environment spawned by :py:attr:`env_spawner` optimizer : :py:class:`torch.nn.Module` A torch optimizer that links to :py:attr:`model`. save_path : `str` The root directory for saving data. Default: the current working directory. pg_cost : `float` Policy gradient cost/multiplier. baseline_cost : `float` Baseline cost/multiplier. entropy_cost : `float` Entropy cost/multiplier. grad_norm_clipping : `float` If bigger 0, clips the computed gradient norm to given maximum value. reward_clipping : `bool` Reward clipping. batchsize_training : `int` Number of complete trajectories to gather before learning from them as batch. rollout : `int` Length of rollout used by the IMPALA algorithm. total_steps : `int` Maximum number of environment steps to learn from. max_epoch : `int` Maximum number of training epochs to do. max_time : `int` Maximum time for training. threads_prefetch : `int` The number of threads that shall prefetch data for training. threads_inference : `int` The number of threads that shall perform inference. threads_store : `int` The number of threads that shall store data into trajectory store. render: `bool` Set True, if episodes shall be rendered. max_gif_length: `bool` The maximum number of frames that shall be saved as a single gif. Set to 0 (default), if no limit shall be enforced. verbose : `bool` Set True if system metrics shall be printed at interval set by `print_interval`. print_interval : `int` Interval of training epoch system metrics shall be printed. Set to 0 to surpress printing. system_log_interval : `int` Interval of logging system metrics. load_checkpoint : `bool` Set True if the most checkpoint shall be loaded. checkpoint_interval : `int` Interval of checkpointing. Set to 0 to surpress checkpointing. max_queued_batches: `int` Limits the number of batches that can be queued at once. max_queued_drops: `int` Limits the number of dropped trajectories that can be queued by the trajectory store. max_queued_stores: `int` Limits the number of states that can be queued to be stored. 
""" def __init__(self, rank: int, num_actors: int, env_spawner: EnvSpawner, model: torch.nn.Module, optimizer: torch.optim.Optimizer, save_path: str = ".", pg_cost: float = 1., baseline_cost: float = 1., entropy_cost: float = 0.01, discounting: float = 0.99, grad_norm_clipping: float = 40., reward_clipping: bool = True, batchsize_training: int = 4, rollout: int = 80, total_steps: int = -1, max_epoch: int = -1, max_time: float = -1., threads_prefetch: int = 1, threads_inference: int = 1, threads_store: int = 1, render: bool = False, max_gif_length: int = 0, verbose: bool = False, print_interval: int = 10, system_log_interval: int = 1, load_checkpoint: bool = False, checkpoint_interval: int = 10, max_queued_batches: int = 128, max_queued_drops: int = 128, max_queued_stores: int = 1024): self.total_num_envs = num_actors * env_spawner.num_envs self.envs_list = [i for i in range(self.total_num_envs)] super().__init__(rank, num_callees=1, num_callers=num_actors, threads_process=threads_inference, caller_class=agents.Actor, caller_args=[env_spawner], future_keys=self.envs_list) # ATTRIBUTES self._save_path = save_path self._model_path = os.path.join(save_path, 'model') self._pg_cost = pg_cost self._baseline_cost = baseline_cost self._entropy_cost = entropy_cost self._discounting = discounting self._grad_norm_clipping = grad_norm_clipping self._reward_clipping = reward_clipping self._batchsize_training = batchsize_training self._rollout = rollout self._total_steps = total_steps self._max_epoch = max_epoch self._max_time = max_time self._verbose = verbose self._print_interval = print_interval self._system_log_interval = system_log_interval self._checkpoint_interval = checkpoint_interval # COUNTERS self.inference_epoch = 0 self.inference_steps = 0 self.inference_time = 0. self.training_epoch = 0 self.training_steps = 0 self.training_time = 0. self.fetching_time = 0. self.runtime = 0 self.dead_counter = 0 # TORCH self.training_device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu") # use 2 gpus if available if torch.cuda.device_count() < 2: self.eval_device = self.training_device else: print("2 GPUs used!") self.eval_device = torch.device("cuda:1") model = DataParallel(model) self.model = model.to(self.training_device) self.eval_model = copy.deepcopy(self.model) self.eval_model = self.eval_model.to(self.eval_device) self.eval_model.eval() self.optimizer = optimizer # define a linear decreasing function for linear scheduler def linear_lambda(epoch): return 1 - min(epoch * rollout * batchsize_training, total_steps) / total_steps self.scheduler = LambdaLR(self.optimizer, linear_lambda) # TOOLS self.recorder = Recorder(save_path=self._save_path, render=render, max_gif_length=max_gif_length) # LOAD CHECKPOINT, IF WANTED if load_checkpoint: self._load_checkpoint(self._model_path) # THREADS self.lock_model = mp.Lock() self.lock_prefetch = mp.Lock() self.shutdown_event = mp.Event() self.queue_drops = mp.Queue(maxsize=max_queued_drops) self.queue_batches = mp.Queue(maxsize=max_queued_batches) self.storing_deque = deque(maxlen=max_queued_stores) # check variables used by _check_dead_queues() self.queue_batches_old = self.queue_batches.qsize() self.queue_drop_off_old = self.queue_drops.qsize() self.queue_rpcs_old = len(self._pending_rpcs) # Create prefetch threads # Not actual threads. 
Name is chosen due to equal API usage in this code self.prefetch_threads = [ mp.Process(target=self._prefetch, args=(self.queue_drops, self.queue_batches, batchsize_training, self.shutdown_event, self.training_device), daemon=True, name='prefetch_thread_%d' % i) for i in range(threads_prefetch) ] self.storing_threads = [ Thread(target=self._store, daemon=True, name='storing_thread_%d' % i) for i in range(threads_store) ] # spawn trajectory store placeholder_eval_obs = self._build_placeholder_eval_obs(env_spawner) self.trajectory_store = TrajectoryStore(self.envs_list, placeholder_eval_obs, self.eval_device, self.queue_drops, self.recorder, trajectory_length=rollout) # start actors self._start_callers() # start threads and processes for thread in [*self.prefetch_threads, *self.storing_threads]: thread.start() # pylint: disable=arguments-differ def _loop(self, waiting_time: float = 5): """Inner loop function of a :py:class:`Learner`. Called by :py:meth:`.RpcCallee.loop()`. This method first pulls a batch in :py:attr:`self.queue_batches`. Then it invokes :py:meth:`_learn_from_batch()` and copies the updated model weights from the learning model to :py:attr:`self.eval_model`. System metrics are passed logged using :py:meth:`~.Recorder.log()`. Finally, it checks for reached shutdown criteria, like :py:attr:`self._total_steps` has been reached. Parameters ---------- waiting_time: `float` Seconds to wait on batches delivered by :py:attr:`self.queue_batches`. """ batch = None try: batch = self.queue_batches.get(timeout=waiting_time) except queue.Empty: pass if batch is not None: training_metrics = self._learn_from_batch( batch, grad_norm_clipping=self._grad_norm_clipping, pg_cost=self._pg_cost, baseline_cost=self._baseline_cost, entropy_cost=self._entropy_cost) # delete Tensors after usage to free memory (see torch multiprocessing) del batch with self.lock_model: self.eval_model.load_state_dict(self.model.state_dict()) self.recorder.log('training', training_metrics) if self._checkpoint_interval > 0: # mus split up to not divide by zero if self.training_epoch % self._checkpoint_interval == 0: # 0 or 1 i = int(self.training_epoch % (2 * self._checkpoint_interval) != 0) self._save_checkpoint(self._model_path, i) # check if queues are dead self._check_dead_queues() # check, if shutdown prerequisites haven been reached self.shutdown = ((self.training_epoch > self._max_epoch > 0) or (self.training_steps > self._total_steps > 0) or (self.get_runtime() > self._max_time > 0) or self.shutdown) if self._loop_iteration == self._system_log_interval and not self.shutdown: system_metrics = self._get_system_metrics() self.recorder.log('system', system_metrics) self._loop_iteration = 0 if self._verbose and (self.training_epoch % self._print_interval == 0): print(pprint.pformat(system_metrics)) # be sure to broadcast or react to shutdown event if self.shutdown_event.is_set(): self.shutdown = True elif self.shutdown: self.shutdown_event.set() def process_batch(self, caller_ids: List[Union[int, str]], *batch: List[dict], **misc: dict) -> Dict[str, torch.Tensor]: """Inner method to process a whole batch at once. Called by :py:meth:`~.RpcCallee._process_batch()`. Before returning the result for the given batch, this method: # . Moves its data to the :py:class:`Learner` device (usually GPU) # . Runs inference on this data # . Invokes :py:meth:`_queue_for_storing()` to put evaluated data on storing queue. Parameters ---------- caller_ids : `list[int]` or `list[str]` List of unique identifiers for callers. 
batch : `list[dict]` List of inputs for evaluation. misc : `dict` Dict of keyword arguments. Primarily used for metrics in this application. """ # concat tensors for each dict in a batch and move to own device for dictionary in batch: for key, value in dictionary.items(): try: # [T, B, C, H, W] => [1, batchsize, C, H, W] dictionary[key] = torch.cat(value, dim=1).to(self.eval_device) except TypeError: # expected for input dictionaries that are not tensors continue # more arguments could be sotred in batch tuple states = batch[0] # run inference start = time.time() with self.lock_model: inference_output, _ = self.eval_model(states) self.inference_time += time.time() - start # log model state at time of inference inference_output['training_steps'] = torch.zeros_like( states['episode_return']).fill_(self.training_steps) self.inference_steps += states['frame'].shape[1] self.inference_epoch += 1 # add states to store in parallel process. Don't move data via RPC as it shall stay on cuda. states = { k: v.detach() for k, v in { **states, **inference_output }.items() } metrics = misc['metrics'] for i, i_caller_id in enumerate(caller_ids): self._queue_for_storing(i_caller_id, {k: v[0, i] for k, v in states.items()}, metrics[i]) # gather an return results results = { c: inference_output['action'][0][i].view(1, 1).cpu().detach() for i, c in enumerate(caller_ids) } return results def _queue_for_storing(self, caller_id: str, state: dict, metrics: dict): """Wrap for storing data onto the :py:obj:`~self.storing_deque`. Parameters ---------- caller_id: `int` or `str` The caller id of the environments the data belongs to. Necessary for proper storing. state: `dict` An environments state dictionary. metrics: `dict` An actors metrics dictionary. """ while True and not self.shutdown: if len(self.storing_deque) < self.storing_deque.maxlen: self.storing_deque.append((caller_id, state, metrics)) break else: time.sleep(0.001) def _store(self, waiting_time: float = 0.1): """Periodically checks for data in :py:obj:`self.storing_deque` and stores found data into the :py:class:`~.TrajectoryStore`. Intended for use as :py:obj:`multiprocessing.Process`. Parameters ---------- waiting_time: `float` The time in seconds to wait for new data. """ while not self.shutdown_event.is_set(): try: caller_id, state, metrics = self.storing_deque.popleft() except IndexError: # deque empty time.sleep(waiting_time) continue self.trajectory_store.add_to_entry(caller_id, state, metrics) del state, metrics def _learn_from_batch(self, batch: Dict[str, torch.Tensor], grad_norm_clipping: float = 40., pg_cost: float = 1., baseline_cost: float = 0.5, entropy_cost: float = 0.01) -> Dict[str, Any]: """Runs the learning process and updates the internal model. This method: # . Evaluates the given :py:attr:`batch` with the internal learning model. # . Invokes :py:meth:`compute_losses()` to get all components of the loss function. # . Calculates the total loss, using the given cost factors for each component. # . Updates the model by invoking the :py:attr:`self.optimizer`. Parameters ---------- batch : `dict` Dict of stacked tensors of complete trajectories as returned by :py:meth:`_to_batch()`. grad_norm_clipping : `float` If bigger 0, clips the computed gradient norm to given maximum value. pg_cost : `float` Cost/Multiplier for policy gradient loss. baseline_cost : `float` Cost/Multiplier for baseline loss. entropy_cost : `float` Cost/Multiplier for entropy regularization. 
""" # evaluate training batch batch_length = batch['current_length'].sum().item() learner_outputs, _ = self.model(batch) pg_loss, baseline_loss, entropy_loss = self.compute_losses( batch, learner_outputs, discounting=self._discounting, reward_clipping=self._reward_clipping) total_loss = pg_cost * pg_loss \ + baseline_cost * baseline_loss \ + entropy_cost * entropy_loss self.training_steps += batch_length self.training_epoch += 1 # perform update self.optimizer.zero_grad() total_loss.backward() if grad_norm_clipping > 0: nn.utils.clip_grad_norm_(self.model.parameters(), grad_norm_clipping) self.optimizer.step() self.scheduler.step() return { "runtime": self.get_runtime(), "training_time": self.training_time, "training_epoch": self.training_epoch, "training_steps": self.training_steps, "total_loss": total_loss.detach().cpu().item(), "pg_loss": pg_loss.detach().cpu().item(), "baseline_loss": baseline_loss.detach().cpu().item(), "entropy_loss": entropy_loss.detach().cpu().item(), } @staticmethod def compute_losses( batch: Dict[str, torch.Tensor], learner_outputs: Dict[str, torch.Tensor], discounting: float = 0.99, reward_clipping: bool = True ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Computes and returns the components of IMPALA loss. Calculates policy gradient, baseline and entropy loss using Vtrace for value estimation. See Also -------- * The :py:mod:`.functional.loss` module. * The :py:mod:`.functional.vtrace` module. Parameters ---------- batch : `dict` Dict of stacked tensors of complete trajectories as returned by :py:meth:`_to_batch()`. learner_outputs : `dict` Dict with outputs generated during evaluation within training. discounting : `float` Reward discout factor, must be a positive smaller than 1. reward_clipping : `bool` If set, rewards are clamped between -1 and 1. """ assert 0 < discounting <= 1. # Take final value function slice for bootstrapping. bootstrap_value = learner_outputs["baseline"][-1] # Move from obs[t] -> action[t] to action[t] -> obs[t]. batch = {key: tensor[1:] for key, tensor in batch.items()} learner_outputs = { key: tensor[:-1] for key, tensor in learner_outputs.items() } # clip rewards, if wanted if reward_clipping: batch["reward"] = torch.clamp(batch["reward"], -1, 1) discounts = (~batch["done"]).float() * discounting vtrace_returns = vtrace.from_logits( behavior_policy_logits=batch["policy_logits"], target_policy_logits=learner_outputs["policy_logits"], values=learner_outputs["baseline"], bootstrap_value=bootstrap_value, actions=batch["action"], rewards=batch["reward"], discounts=discounts, ) pg_loss = loss.policy_gradient(learner_outputs["policy_logits"], batch["action"], vtrace_returns.pg_advantages) baseline_loss = F.mse_loss(learner_outputs["baseline"], vtrace_returns.vs, reduction='sum') entropy_loss = loss.entropy(learner_outputs["policy_logits"]) return pg_loss, baseline_loss, entropy_loss @staticmethod def _prefetch(in_queue: mp.Queue, out_queue: mp.Queue, batchsize: int, shutdown_event: mp.Event, target_device, waiting_time=5): """Continuously prefetches complete trajectories dropped by the :py:class:`~.TrajectoryStore` for training. As long as shutdown is not set, this method pulls :py:attr:`batchsize` trajectories from :py:attr:`in_queue`, transforms them into batches using :py:meth:`~_to_batch()` and puts them onto the :py:attr:`out_queue`. This usually runs as an asynchronous :py:obj:`multiprocessing.Process`. 
        Parameters
        ----------
        in_queue: :py:obj:`multiprocessing.Queue`
            A queue that delivers dropped trajectories from the
            :py:class:`~.TrajectoryStore`.
        out_queue: :py:obj:`multiprocessing.Queue`
            A queue that delivers batches to :py:meth:`_loop()`.
        batchsize: `int`
            The number of trajectories that shall be processed into a batch.
        shutdown_event: :py:obj:`multiprocessing.Event`
            An event that breaks this method's internal loop.
        target_device: :py:obj:`torch.device`
            The target device of the batch.
        waiting_time: `float`
            Time the method's loop sleeps between iterations.
        """
        while not shutdown_event.is_set():
            try:
                trajectories = [
                    in_queue.get(timeout=waiting_time)
                    for _ in range(batchsize)
                ]
            except queue.Empty:
                continue

            batch = Learner._to_batch(trajectories, target_device)
            # delete tensors after usage to free memory
            # (see torch multiprocessing)
            del trajectories

            try:
                out_queue.put(batch)
            except (AssertionError, ValueError):
                # queue closed
                continue

            # delete tensors after usage to free memory
            # (see torch multiprocessing)
            del batch

    @staticmethod
    def _to_batch(trajectories: List[dict],
                  target_device) -> Dict[str, torch.Tensor]:
        """Extracts states from a list of trajectories and returns them as a
        batch.

        Parameters
        ----------
        trajectories: `list`
            List of trajectories dropped by the
            :py:class:`~.TrajectoryStore`.
        target_device: :py:obj:`torch.device`
            The target device of the batch.
        """
        states = listdict_to_dictlist([t['states'] for t in trajectories])
        for key, value in states.items():
            # [T, B, C, H, W] => [len(trajectories), batchsize, C, H, W]
            states[key] = torch.cat(value, dim=1).clone().to(target_device)

        states['current_length'] = torch.stack(
            [t['current_length'] for t in trajectories]).clone()

        return states

    def _get_system_metrics(self):
        """Returns the training system's metrics."""
        return {
            "runtime": self.get_runtime(),
            "trajectories_seen": self.recorder.trajectories_seen,
            "episodes_seen": self.recorder.episodes_seen,
            "mean_inference_latency": self.recorder.mean_latency,
            "fetching_time": self.fetching_time,
            "inference_time": self.inference_time,
            "inference_steps": self.inference_steps,
            "training_time": self.training_time,
            "training_steps": self.training_steps,
            "queue_batches": self.queue_batches.qsize(),
            "queue_drops": self.queue_drops.qsize(),
            "queue_rpcs": len(self._pending_rpcs),
            "queue_storing": len(self.storing_deque),
        }

    def _save_model(self, path: str, filename: str = 'final_model.pt'):
        """Saves the model at the given path.

        Parameters
        ----------
        path: `str`
            A valid path to a directory.
        filename: `str`
            The filename the saved model shall have. Defaults to
            ``final_model.pt``.
        """
        os.makedirs(path, exist_ok=True)
        save_path = os.path.join(path, filename)
        torch.save(self.model.state_dict(), save_path)
        print('Final model saved at %s.' % save_path)

    def _save_checkpoint(self, path: str, i: int):
        """Saves a checkpoint at the given path.

        The filename is fixed to ``checkpoint_i.pt``, ``i`` being an integer.

        Parameters
        ----------
        path: `str`
            A valid path to a directory.
        i: `int`
            The filename's suffix. This is used to enable multiple
            consecutive checkpoints.
        """
        with warnings.catch_warnings():
            # a UserWarning is thrown on scheduler.state_dict(), reminding us
            # that the optimizer's state should be saved as well;
            # disable this warning because we do save the optimizer's state
            warnings.simplefilter("ignore", category=UserWarning)
            model_dict = {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict()
            }

        training_dict = {
            "_pg_cost": self._pg_cost,
            "_baseline_cost": self._baseline_cost,
            "_entropy_cost": self._entropy_cost,
            "_discounting": self._discounting,
            "_grad_norm_clipping": self._grad_norm_clipping,
            "_reward_clipping": self._reward_clipping,
            "_batchsize_training": self._batchsize_training,
            "_rollout": self._rollout,
            "_total_steps": self._total_steps,
        }

        counters_dict = {
            "runtime": self.get_runtime(),
            "trajectories_seen": self.recorder.trajectories_seen,
            "episodes_seen": self.recorder.episodes_seen,
            "training_steps": self.training_steps,
            "training_epoch": self.training_epoch,
            "inference_steps": self.inference_steps,
        }

        save_dict = {**model_dict, **training_dict, **counters_dict}

        os.makedirs(path, exist_ok=True)
        save_path = os.path.join(path, 'checkpoint_%d.pt' % i)
        torch.save(save_dict, save_path)

        if self._verbose:
            print('Made checkpoint after training epoch %d.' %
                  self.training_epoch)

    def _load_checkpoint(self, path: str):
        """Loads the latest checkpoint from the given path.

        Data will be loaded directly into the program's components. The files
        ``checkpoint_0.pt`` and ``checkpoint_1.pt`` will be checked at the
        given path.

        Parameters
        ----------
        path: `str`
            A valid path to a directory.
        """
        try:
            checkpoint_0 = torch.load(os.path.join(path, 'checkpoint_0.pt'))
        except FileNotFoundError:
            checkpoint_0 = None
        try:
            checkpoint_1 = torch.load(os.path.join(path, 'checkpoint_1.pt'))
        except FileNotFoundError:
            checkpoint_1 = None

        if (checkpoint_0 is not None
                and (checkpoint_1 is None or checkpoint_0['training_epoch']
                     >= checkpoint_1['training_epoch'])):
            checkpoint = checkpoint_0
            del checkpoint_1
        elif (checkpoint_1 is not None
              and (checkpoint_0 is None or checkpoint_1['training_epoch'] >
                   checkpoint_0['training_epoch'])):
            checkpoint = checkpoint_1
            del checkpoint_0
        else:
            raise FileNotFoundError('No checkpoints found!')

        # model
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.eval_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        with warnings.catch_warnings():
            # a UserWarning is thrown on scheduler.load_state_dict(),
            # reminding us that the optimizer's state should be loaded as
            # well; disable it because we do load the optimizer's state
            warnings.simplefilter("ignore", category=UserWarning)
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # training
        self._pg_cost = checkpoint['_pg_cost']
        self._baseline_cost = checkpoint['_baseline_cost']
        self._entropy_cost = checkpoint['_entropy_cost']
        self._discounting = checkpoint['_discounting']
        self._grad_norm_clipping = checkpoint['_grad_norm_clipping']
        self._reward_clipping = checkpoint['_reward_clipping']
        self._batchsize_training = checkpoint['_batchsize_training']
        self._rollout = checkpoint['_rollout']

        # counters
        self._t_start = time.time() - checkpoint['runtime']
        self.recorder.trajectories_seen = checkpoint['trajectories_seen']
        self.recorder.episodes_seen = checkpoint['episodes_seen']
        self.training_steps = checkpoint['training_steps']
        self.training_epoch = checkpoint['training_epoch']
        self.inference_steps = checkpoint['inference_steps']

    @staticmethod
    def _build_placeholder_eval_obs(
            env_spawner: EnvSpawner) -> Dict[str, torch.Tensor]:
        """Returns a dictionary that mimics an evaluated observation with all
        values being 0.
        Parameters
        ----------
        env_spawner: :py:class:`.EnvSpawner`
            An :py:class:`.EnvSpawner` that holds information about the
            environment that can be spawned.
        """
        placeholder_eval_obs = env_spawner.placeholder_obs
        placeholder_eval_obs['action'] = torch.zeros(1, 1)
        placeholder_eval_obs['baseline'] = torch.zeros(1, 1)
        placeholder_eval_obs['policy_logits'] = torch.zeros(
            1, 1, env_spawner.env_info['action_space'].n)
        placeholder_eval_obs['training_steps'] = torch.zeros(1, 1)

        return placeholder_eval_obs

    def _check_dead_queues(self, dead_threshold: int = 500):
        """Checks whether all queues have kept the same length for a chosen
        number of consecutive checks. If so, the queues are assumed to be
        dead and the global shutdown is initiated.

        Parameters
        ----------
        dead_threshold: `int`
            The number of consecutive unchanged checks after which the
            queues are assumed to be dead.
        """
        if (self.queue_batches_old == self.queue_batches.qsize()) \
                and (self.queue_drop_off_old == self.queue_drops.qsize()) \
                and (self.queue_rpcs_old == len(self._pending_rpcs)):
            self.dead_counter += 1
        else:
            self.dead_counter = 0
            self.queue_batches_old = self.queue_batches.qsize()
            self.queue_drop_off_old = self.queue_drops.qsize()
            self.queue_rpcs_old = len(self._pending_rpcs)

        if self.dead_counter > dead_threshold:
            print("\n==========================================")
            print("CLOSING DUE TO DEAD QUEUES. (Used CTRL+C?)")
            print("==========================================\n")
            self.shutdown = True

    def _cleanup(self, waiting_time=5):
        """Cleans up after the main loop is done. Called by
        :py:meth:`~.RpcCallee.loop()`.

        Overwrites and calls :py:meth:`~.RpcCallee._cleanup()`.
        """
        self._save_model(self._model_path)
        self.runtime = self.get_runtime()

        self.queue_batches.close()
        self.queue_drops.close()
        self.trajectory_store.del_all()
        self.shutdown_event.set()

        super()._cleanup()

        # write last buffers
        print("Write and empty log buffers.")
        # pylint: disable=protected-access
        self.recorder._logger.write_buffers()

        print("Empty queues.")
        while self.queue_batches.qsize() > 0:
            batch = self.queue_batches.get(timeout=waiting_time)
            del batch
        while self.queue_drops.qsize() > 0:
            drop = self.queue_drops.get(timeout=waiting_time)
            del drop

        # join threads to ensure freeing of resources
        print("Join threads.")
        for thread in [*self.prefetch_threads, *self.storing_threads]:
            try:
                thread.join(timeout=waiting_time)
            except RuntimeError:
                # timeout; thread died during shutdown
                pass
        self.queue_batches.join_thread()
        self.queue_drops.join_thread()

        print("Empty CUDA cache.")
        torch.cuda.empty_cache()

        # run garbage collection to ensure freeing of resources
        print("Running garbage collection.")
        gc.collect()

        if self._verbose:
            self._report()

    def _report(self):
        """Reports data to the CLI."""
        if self.runtime > 0:
            print("\n============== REPORT ==============")
            fps = self.inference_steps / self.runtime
            print("inferred", str(self.inference_steps), "steps")
            print("in", str(self.runtime), "seconds")
            print("==>", str(fps), "fps")

            fps = self.training_steps / self.runtime
            print("trained", str(self.training_steps), "steps")
            print("in", str(self.runtime), "seconds")
            print("==>", str(fps), "fps")

            print("Total inference_time:", str(self.inference_time),
                  "seconds")
            print("Total training_time:", str(self.training_time), "seconds")
            print("Total fetching_time:", str(self.fetching_time), "seconds")
            print("Mean inference latency:", str(self.recorder.mean_latency),
                  "seconds")

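# Hedged sketch: the two-slot checkpointing used by _save_checkpoint() and
# _load_checkpoint() above, reduced to a standalone helper. It assumes each
# checkpoint file stores a 'training_epoch' counter, as _save_checkpoint()
# does; the function name is illustrative.
import os

import torch


def latest_checkpoint(path: str) -> dict:
    """Returns the newer of checkpoint_0.pt / checkpoint_1.pt by epoch."""
    candidates = []
    for i in (0, 1):
        file_path = os.path.join(path, 'checkpoint_%d.pt' % i)
        if os.path.isfile(file_path):
            candidates.append(torch.load(file_path, map_location='cpu'))
    if not candidates:
        raise FileNotFoundError('No checkpoints found!')
    # alternating writes mean the higher training_epoch is the latest slot
    return max(candidates, key=lambda c: c['training_epoch'])
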
def train(args):
    # device
    device = torch.device('cuda:{}'.format(args.gpu) if args.gpu >= 0
                          and torch.cuda.is_available() else 'cpu')
    if args.gpu >= 0 and torch.cuda.is_available():
        cudnn.benchmark = True

    # dtype
    if args.type == 'float64':
        dtype = torch.float64
    elif args.type == 'float32':
        dtype = torch.float32
    elif args.type == 'float16':
        dtype = torch.float16
    else:
        raise ValueError('Wrong type: expected float64, float32 or float16.')

    # model
    model, run = get_model(args)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info(model)
    logging.info('number of parameters: {}'.format(num_parameters))

    # dataset
    train_loader = get_loader(args.im_path, args.gt_path, args.training_list,
                              args.batch_size, args.n_worker)
    eval_loader = get_loader(args.im_path, args.gt_path, args.eval_list, 1,
                             args.n_worker, training=False)

    # loss: classification cross-entropy, ignoring the void label 255
    criterion = nn.CrossEntropyLoss(ignore_index=255)

    # to device
    if args.gpu >= 0:
        model = torch.nn.DataParallel(model, [args.gpu])
    model.to(device=device, dtype=dtype)
    criterion.to(device=device, dtype=dtype)

    # load pretrained backbone weights
    logging.info("=> loading checkpoint '{}'".format(args.pretrained_model))
    checkpoint = torch.load(args.pretrained_model, map_location=device)
    model.module.backbone.load_state_dict(checkpoint['state_dict'],
                                          strict=False)
    logging.info("=> loaded checkpoint '{}'".format(args.pretrained_model))

    # optionally load fine-tuned weights for the full model
    if args.ft_model is not None:
        logging.info("=> loading ft model '{}'".format(args.ft_model))
        checkpoint = torch.load(args.ft_model, map_location=device)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        logging.info("=> loaded ft model '{}'".format(args.ft_model))

    # optimizer: separate learning rates for backbone and decoder
    max_iter = args.max_epoch * len(train_loader)
    optimizer = optim.SGD([{
        'params': get_backbone_params(model),
        'lr': args.lr
    }, {
        'params': get_decoder_params(model),
        'lr': args.last_layer_lr_mult * args.lr
    }],
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.decay)
    scheduler = LambdaLR(optimizer,
                         lr_lambda=[
                             lambda iter: lr_poly(iter, max_iter, args.gamma),
                             lambda iter: lr_poly(iter, max_iter, args.gamma)
                         ])

    model.train()
    if args.freeze_bn:
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    for epoch in trange(args.max_epoch):
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            iter_idx = epoch * len(train_loader) + batch_idx
            data = data.to(device=device, dtype=dtype)
            target = target.to(device=device)
            run(model, criterion, optimizer, data, target, scheduler,
                iter_idx, args)

        iter_idx = (epoch + 1) * len(train_loader)
        torch.save(
            {
                'iter': iter_idx,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            os.path.join(args.exp_path,
                         'checkpoint_{}.pth.tar'.format(iter_idx)))

        model.eval()
        m_iou = evaluate(model, eval_loader, args.n_class, device, dtype,
                         iter_idx, args.writer)
        model.train()
        if args.freeze_bn:
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()

        logging.info('Train Epoch: {} [{}/{} ({:.0f}%)], mIOU: {:.6f}'
                     ', 1x_lr: {}, 10x_lr: {}'.format(
                         epoch, batch_idx, len(train_loader),
                         100. * batch_idx / len(train_loader), m_iou,
                         optimizer.param_groups[0]['lr'],
                         optimizer.param_groups[1]['lr']))
        args.writer.add_scalar('val/m_iou', m_iou, iter_idx)

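# The training script above assumes a helper lr_poly() for its LambdaLR
# multipliers; it is not defined in this file. A minimal sketch under that
# assumption, matching the common "poly" schedule used for semantic
# segmentation (multiplier decays from 1 to 0 over max_iter steps):
def lr_poly(iteration, max_iter, gamma=0.9):
    """Polynomial decay multiplier: (1 - iteration / max_iter) ** gamma."""
    return (1 - float(iteration) / max_iter)**gamma
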
def trainModel(lr, ks, fm, train_loader, w):
    kernel_size = ks
    feature_maps = fm
    learn_rate = lr
    momentum_rate = 0.75
    cyclic_rate = 120
    epochs = 14
    net = initialNetGenerator(kernel_size, feature_maps, train_loader)

    # class weights estimated from the training data (a fixed alternative
    # would be alpha = 0.4; weight_map = np.array([alpha, 1 - alpha]))
    weight_map = getWeightMap(train_loader)

    training_parameters = ("SGD Learning Rate: {}\n Momentum: {}\n "
                           "Cycle Length: {}\n Number of epochs: {}\n "
                           "Weight Map: {}").format(learn_rate, momentum_rate,
                                                    cyclic_rate, epochs,
                                                    weight_map)
    model_parameters = "Kernel Size: {} Initial Feature Maps: {}".format(
        kernel_size, feature_maps)
    w.add_text('Training Parameters', training_parameters)
    w.add_text('Model Parameters', model_parameters)

    weight_map = tensor_format(torch.FloatTensor(weight_map))
    # weighted cross-entropy; BCEWithLogitsLoss and (UnBiased)DiceLoss were
    # tried as alternatives
    criterion = nn.CrossEntropyLoss(weight=weight_map)
    optimizer = optim.SGD(net.parameters(), lr=learn_rate,
                          momentum=momentum_rate)
    optimizer2 = optim.SGD(net.parameters(), lr=0.2 * learn_rate,
                           momentum=0.9 * momentum_rate)
    scheduler = LambdaLR(optimizer, lr_lambda=cosine(cyclic_rate))
    scheduler2 = LambdaLR(optimizer2, lr_lambda=cosine(cyclic_rate))

    count = 0
    for epoch in range(epochs):
        for img, label in train_loader:
            count += 1

            # feed forward
            img, label = tensor_format(img), tensor_format(label)
            output = net(img)
            output, label = crop(output, label)
            logInitialCellProb(output, count, w, g_dict_of_images)
            loss = criterion(output, label)

            # logging
            w.add_scalar('Loss', loss.data[0], count)
            acc = score(output, label)
            w.add_scalar('Accuracy', float(acc), count)
            w.add_scalar('Percentage of Dead Neurons',
                         net.final_conv_dead_neurons, count)

            # update (optimizer2/scheduler2 were used for late epochs in an
            # earlier experiment and are kept for reference)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
    return net

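# trainModel() above passes cosine(cyclic_rate) as lr_lambda, but the factory
# is not shown in this file. A plausible sketch, assuming a cyclic cosine
# wave with period cycle_len that restarts at the full learning rate:
import math


def cosine(cycle_len):
    def lr_lambda(iteration):
        # multiplier oscillates between 1.0 and 0.0 with period cycle_len
        return 0.5 * (1 + math.cos(math.pi * (iteration % cycle_len) /
                                   cycle_len))

    return lr_lambda
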
    def train(self,
              model,
              train_loader,
              val_loader=None,
              num_epochs=10,
              log_nth=0,
              model_args={}):
        """
        Train a given model with the provided data.

        Inputs:
        - model: model object initialized from a torch.nn.Module
        - train_loader: train data in torch.utils.data.DataLoader
        - val_loader: val data in torch.utils.data.DataLoader
        - num_epochs: total number of training epochs
        - log_nth: log training accuracy and loss every nth iteration
        """
        self.writer = tb.SummaryWriter(self.tb_dir)
        self.val_writer = tb.SummaryWriter(self.tb_val_dir)

        # filter out frcnn if this is added to the module
        parameters = [
            param for name, param in model.named_parameters()
            if 'frcnn' not in name
        ]
        optim = self.optim(parameters, **self.optim_args)

        if self.lr_scheduler_lambda:
            scheduler = LambdaLR(optim, lr_lambda=self.lr_scheduler_lambda)
        else:
            scheduler = None

        self._reset_histories()
        iter_per_epoch = len(train_loader)

        print('START TRAIN.')
        ########################################################################
        # TODO:                                                                #
        # Write your own personal training method for our solver. In each     #
        # epoch iter_per_epoch shuffled training batches are processed. The   #
        # loss for each batch is stored in self.train_loss_history. Every     #
        # log_nth iteration the loss is logged. After one epoch the training  #
        # accuracy of the last mini batch is logged and stored in             #
        # self.train_acc_history. We validate at the end of each epoch, log   #
        # the result and store the accuracy of the entire validation set in   #
        # self.val_acc_history.                                               #
        #                                                                      #
        # Your logging should look something like:                            #
        #   ...                                                               #
        #   [Iteration 700/4800] TRAIN loss: 1.452                            #
        #   [Iteration 800/4800] TRAIN loss: 1.409                            #
        #   [Iteration 900/4800] TRAIN loss: 1.374                            #
        #   [Epoch 1/5] TRAIN acc/loss: 0.560/1.374                           #
        #   [Epoch 1/5] VAL acc/loss: 0.539/1.310                             #
        #   ...                                                               #
        ########################################################################

        for epoch in range(num_epochs):
            # TRAINING
            if scheduler:
                scheduler.step()
                print("[*] New learning rate(s): {}".format(
                    scheduler.get_lr()))

            now = time.time()
            for i, batch in enumerate(train_loader, 1):
                optim.zero_grad()
                losses = model.sum_losses(batch, **model_args)
                losses['total_loss'].backward()
                optim.step()

                for k, v in losses.items():
                    if k not in self._losses.keys():
                        self._losses[k] = []
                    self._losses[k].append(v.data.cpu().numpy())

                if log_nth and i % log_nth == 0:
                    next_now = time.time()
                    print('[Iteration %d/%d] %.3f s/it' %
                          (i + epoch * iter_per_epoch,
                           iter_per_epoch * num_epochs,
                           (next_now - now) / log_nth))
                    now = next_now
                    for k, v in self._losses.items():
                        last_log_nth_losses = self._losses[k][-log_nth:]
                        train_loss = np.mean(last_log_nth_losses)
                        print('%s: %.3f' % (k, train_loss))
                        self.writer.add_scalar(k, train_loss,
                                               i + epoch * iter_per_epoch)

            # VALIDATION
            if val_loader and log_nth:
                model.eval()
                for i, batch in enumerate(val_loader):
                    losses = model.sum_losses(batch, **model_args)
                    for k, v in losses.items():
                        if k not in self._val_losses.keys():
                            self._val_losses[k] = []
                        self._val_losses[k].append(v.data.cpu().numpy())
                    if i >= log_nth:
                        break
                model.train()

                for k, v in self._val_losses.items():
                    last_log_nth_losses = self._val_losses[k][-log_nth:]
                    val_loss = np.mean(last_log_nth_losses)
                    self.val_writer.add_scalar(k, val_loss,
                                               (epoch + 1) * iter_per_epoch)

            self.snapshot(model, (epoch + 1) * iter_per_epoch)
            self._reset_histories()

        self.writer.close()
        self.val_writer.close()
        ########################################################################
        #                           END OF YOUR CODE                           #
        ########################################################################
        print('FINISH.')

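# Self-contained illustration of the epoch-wise LambdaLR pattern the solver
# above expects via lr_scheduler_lambda (all names below are illustrative):
# a 10x decay every 30 epochs, expressed as a multiplicative factor on the
# base learning rate.
import torch
from torch.optim.lr_scheduler import LambdaLR

_model = torch.nn.Linear(4, 2)
_optim = torch.optim.SGD(_model.parameters(), lr=1e-3)
_scheduler = LambdaLR(_optim, lr_lambda=lambda epoch: 0.1**(epoch // 30))
for _epoch in range(90):
    # ... one training epoch with _optim.step() calls would go here ...
    _scheduler.step()  # lr is 1e-3, then 1e-4 from epoch 30, 1e-5 from 60
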
def train(model, epochs, train_loader, val_loader):
    torch.set_num_threads(args.c)
    optimizer = Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-6)

    def lambda1(x):
        # warm-up: scale the base LR by 0.1 * 10 ** (x / 8), capped at 1.0
        return min((1e-1) * (sqrt(sqrt(sqrt(10)))**min(x, 50)), 1)

    scheduler = LambdaLR(optimizer, lr_lambda=lambda1)
    optimizer.zero_grad()
    gradsum = 0

    simple_logger = SimpleLogger(
        f"{model_folder}/{modelname}_with-masks={with_masks}_map-to-zero={map_to_zero}_no-classes={no_classes}_seq-len={seq_len}_time={datetime.now()}.csv",
        ['epoch', 'loss', 'val_loss', 'grad_norm', 'learning_rate'])
    additional_info_dummy = torch.zeros(10)
    best_val_loss = 1000

    for epoch in trange(epochs, desc='epochs'):
        model.train()
        # save batch losses
        epoch_train_loss = []
        epoch_val_loss = []

        # train on batches
        for pov, act in tqdm(train_loader, desc='batch_train', position=0,
                             leave=True):
            # reset lstm_state after each sequence
            lstm_state = model.get_zero_state(batchsize)

            # swap batch and sequence dims, then move the channel axis in
            # front of H and W. Be careful in testing: this operation has to
            # be matched there!
            pov = pov.transpose(0, 1).transpose(2, 4).transpose(3,
                                                                4).contiguous()

            # move to gpu if not there yet
            if not pov.is_cuda or not act.is_cuda:
                pov, act = pov.to(deviceStr), act.to(deviceStr)

            prediction, lstm_state = model.forward(pov,
                                                   additional_info_dummy,
                                                   lstm_state)
            loss = categorical_loss(act, prediction)
            loss = loss.sum()
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), 10)
            gradsum += float(grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            epoch_train_loss.append(loss.item())

        # eval
        with torch.no_grad():
            model.eval()
            for pov, act in tqdm(val_loader, desc='batch_eval', position=0,
                                 leave=True):
                # reset lstm_state
                lstm_state = model.get_zero_state(batchsize)
                pov = pov.transpose(0, 1).transpose(2, 4).transpose(
                    3, 4).contiguous()

                # move to gpu if not there yet
                if not pov.is_cuda or not act.is_cuda:
                    pov, act = pov.to(deviceStr), act.to(deviceStr)

                prediction, lstm_state = model.forward(
                    pov, additional_info_dummy, lstm_state)
                val_loss = categorical_loss(act, prediction)
                val_loss = val_loss.sum()
                epoch_val_loss.append(val_loss.item())

        if (epoch % 5) == 0:
            print("------------------Saving Model!-----------------------")
            torch.save(
                model.state_dict(),
                f"{model_folder}/{modelname}_with-masks={with_masks}_map-to-zero={map_to_zero}_no-classes={no_classes}_seq-len={seq_len}_epoch={epoch}_time={datetime.now()}.tm"
            )

        mean_val_loss = sum(epoch_val_loss) / len(epoch_val_loss)
        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            torch.save(
                model.state_dict(),
                f"{model_folder}/{modelname}_with-masks={with_masks}_map-to-zero={map_to_zero}_no-classes={no_classes}_seq-len={seq_len}_epoch={epoch}_time={datetime.now()}.tm"
            )

        print("-------------Logging!!!-------------")
        simple_logger.log([
            epoch,
            sum(epoch_train_loss) / len(epoch_train_loss),
            sum(epoch_val_loss) / len(epoch_val_loss), gradsum,
            float(optimizer.param_groups[0]["lr"])
        ])
        gradsum = 0
        scheduler.step()

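# Quick sanity check (sketch) of the warm-up lambda used in train() above:
# it scales the base LR by 0.1 * 10 ** (epoch / 8), capped at 1.0, so the
# full learning rate is reached after 8 scheduler steps.
from math import sqrt

_warmup = lambda x: min((1e-1) * (sqrt(sqrt(sqrt(10)))**min(x, 50)), 1)
print([round(_warmup(e), 3) for e in range(10)])
# -> approximately [0.1, 0.133, 0.178, 0.237, 0.316, 0.422, 0.562, 0.75, 1, 1]
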
                m.gamma = TD_gamma
    else:
        TD_gamma = args.TD_gamma
    if args.TD_alpha_final > 0:
        TD_alpha = args.TD_alpha_final - (
            ((args.epochs - 1 - epoch) /
             (args.epochs - 1))**args.ramping_power) * (args.TD_alpha_final -
                                                        args.TD_alpha)
        for m in model.modules():
            if hasattr(m, 'alpha'):
                m.alpha = TD_alpha
    else:
        TD_alpha = args.TD_alpha
    return TD_gamma, TD_alpha


scheduler = LambdaLR(optimizer, lr_lambda=[schedule])

# Prepare logging
columns = [
    'ep', 'lr', 'tr_loss', 'tr_acc', 'tr_time', 'te_loss', 'te_acc',
    'te_time', 'wspar', 'aspar', 'agspar'
]
if args.TD_gamma_final > 0 or args.TD_alpha_final > 0:
    columns += ['gamma', 'alpha']
if args.cg_groups > 1:
    columns += ['cgspar', 'cgthre']

for epoch in range(args.epochs):
    time_ep = time.time()
    TD_gamma, TD_alpha = update_gamma_alpha(epoch)
    if 'TD' in args.model:
def run_test_multiple(style_weight=10.0,
                      content_weight=1.0,
                      total_variation_weight=0.1,
                      n_epoch=100,
                      batch_size=8,
                      style_path="./data/train_9/"):
    from nntoolbox.vision.learner import MultipleStylesTransferLearner
    from nntoolbox.vision.utils import UnlabelledImageDataset, \
        PairedDataset, UnlabelledImageListDataset
    from nntoolbox.utils import get_device
    from nntoolbox.callbacks import Tensorboard, MultipleMetricLogger, \
        ModelCheckpoint, ToDeviceCallback, ProgressBarCB, MixedPrecisionV2, \
        LRSchedulerCB
    from torch.optim.lr_scheduler import LambdaLR
    from src.models import GenericDecoder, MultipleStyleTransferNetwork, \
        PixelShuffleDecoder, PixelShuffleDecoderV2, MultipleStyleUNet, \
        SimpleDecoder
    from torchvision.models import vgg19
    from torch.utils.data import DataLoader
    from torchvision.transforms import Compose, Resize, RandomCrop
    from torch.optim import Adam

    # ImageNet normalization statistics
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    print("Begin creating dataset")
    style_paths_train = ["./data/train_" + str(i) + "/" for i in range(1, 8)]
    style_paths_val = ["./data/train_8/", "./data/train_9/"]
    content_images = UnlabelledImageListDataset(
        "data/train2014/",
        transform=Compose([Resize(512), RandomCrop((256, 256))]))
    train_style = UnlabelledImageListDataset(
        style_paths_train,
        transform=Compose([Resize(512), RandomCrop((256, 256))]))
    val_style = UnlabelledImageListDataset(
        style_paths_val,
        transform=Compose([Resize(512), RandomCrop((256, 256))]))

    print("Begin splitting data")
    train_size = int(0.80 * len(content_images))
    val_size = len(content_images) - train_size
    train_content, val_content = torch.utils.data.random_split(
        content_images, [train_size, val_size])

    train_dataset = PairedDataset(train_content, train_style)
    val_dataset = PairedDataset(val_content, val_style)
    train_sampler = RandomSampler(train_dataset,
                                  replacement=True,
                                  num_samples=8)
    val_sampler = RandomSampler(val_dataset, replacement=True, num_samples=8)

    print("Begin creating data loaders")
    dataloader = DataLoader(train_dataset, sampler=train_sampler,
                            batch_size=8)
    dataloader_val = DataLoader(val_dataset, sampler=val_sampler,
                                batch_size=8)

    print("Creating models")
    feature_extractor = FeatureExtractor(model=vgg19(True),
                                         fine_tune=False,
                                         mean=mean,
                                         std=std,
                                         device=get_device(),
                                         last_layer=20)
    print("Finish creating feature extractor")

    decoder = PixelShuffleDecoderV2()  # alternative: SimpleDecoder()
    print("Finish creating decoder")

    model = MultipleStyleTransferNetwork(
        encoder=FeatureExtractor(model=vgg19(True),
                                 fine_tune=False,
                                 mean=mean,
                                 std=std,
                                 device=get_device(),
                                 last_layer=20),
        decoder=decoder,
        extracted_feature=20)

    optimizer = Adam(model.parameters(), lr=1e-4)
    # inverse-time decay of the learning rate, applied per iteration
    lr_scheduler = LRSchedulerCB(scheduler=LambdaLR(
        optimizer, lr_lambda=lambda iter: 1 / (1.0 + 5e-5 * iter)),
                                 timescale='iter')

    learner = MultipleStylesTransferLearner(
        dataloader,
        dataloader_val,
        model,
        feature_extractor,
        optimizer=optimizer,
        style_layers={1, 6, 11, 20},
        total_variation_weight=total_variation_weight,
        style_weight=style_weight,
        content_weight=content_weight,
        device=get_device())

    every_iter = eval_every = print_every = compute_num_batch(
        len(train_style), batch_size)
    n_iter = every_iter * n_epoch

    callbacks = [
        ToDeviceCallback(),
        # MixedPrecisionV2(),
        Tensorboard(every_iter=every_iter, every_epoch=1),
        MultipleMetricLogger(iter_metrics=[
            "content_loss", "style_loss", "total_variation_loss", "loss"
        ],
                             print_every=print_every),
        lr_scheduler,
        ModelCheckpoint(learner=learner,
                        save_best_only=False,
                        filepath='weights/model.pt'),
        # ProgressBarCB(range(print_every))
    ]

    learner.learn(n_iter=n_iter, callbacks=callbacks, eval_every=eval_every)

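# run_test_multiple() above relies on compute_num_batch() from
# nntoolbox.utils to size its logging and evaluation intervals. A minimal
# sketch of what such a helper computes (an assumption, not the library's
# actual code):
import math


def compute_num_batch(num_examples, batch_size):
    """Number of mini-batches needed to cover the dataset once."""
    return math.ceil(num_examples / batch_size)
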
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer]
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex "
                "to use fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=params.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        scheduler = LambdaLR(optimizer,
                             lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = Adam(optimizer_grouped_parameters,
                         lr=params.learning_rate)
        scheduler = LambdaLR(optimizer,
                             lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))

    # train and evaluate the model
    logging.info("Starting training for {} epoch(s)".format(params.epoch_num))
    train_and_evaluate(model, train_data, val_data, optimizer, scheduler,
                       params, args.model_dir, args.restore_file)

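# The epoch-wise inverse-time decay used above, in isolation: the multiplier
# 1 / (1 + 0.05 * epoch) halves the learning rate by epoch 20. A minimal,
# self-contained check (names below are illustrative):
import torch
from torch.optim.lr_scheduler import LambdaLR

_params = [torch.nn.Parameter(torch.zeros(1))]
_optimizer = torch.optim.Adam(_params, lr=1e-3)
_scheduler = LambdaLR(_optimizer,
                      lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))
for _epoch in range(20):
    _optimizer.step()
    _scheduler.step()
print(_scheduler.get_last_lr())  # -> [0.0005], i.e. half the base lr
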
class _Trainer(object):
    def __init__(self):
        self.device = torch.device(cfg.device)
        self.max_epoch = cfg.max_epoch
        self.train_dataset = NewDataset(train_set=True)
        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=cfg.num_worker,
            collate_fn=self.train_dataset.collate_fn)
        self.val_dataset = NewDataset(train_set=False)
        self.val_dataloader = DataLoader(
            self.val_dataset,
            batch_size=1,
            shuffle=True,
            num_workers=cfg.num_worker,
            collate_fn=self.val_dataset.collate_fn)
        self.len_train_dataset = len(self.train_dataset)
        self.model = build_model(cfg.model)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=cfg.lr_start,
                                         momentum=cfg.momentum,
                                         weight_decay=cfg.weight_decay)
        if cfg.linear_lr:
            # linear decay from 1.0 to hyp['lrf'] == 0.2
            lf = lambda x: (1 - x / (cfg.max_epoch - 1)) * (1.0 - 0.2) + 0.2
        else:
            # cosine decay 1 -> hyp['lrf']
            lf = one_cycle(1, 0.2, cfg.max_epoch)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lf)
        # alternatives tried: adjust_lr_by_wave(...), adjust_lr_by_loss(...)
        self.writer = SummaryWriter(cfg.tensorboard_path)
        self.iter = 0
        self.cocoGt = COCO(cfg.test_json)

    def put_log(self, epoch_index, mean_loss, time_per_iter):
        print("[epoch:{}|{}] [iter:{}|{}] time:{}s loss:{} giou_loss:{} "
              "conf_loss:{} cls_loss:{} lr:{}".format(
                  epoch_index + 1, self.max_epoch, self.iter + 1,
                  math.ceil(self.len_train_dataset / cfg.batch_size),
                  round(time_per_iter, 2), round(mean_loss[0], 4),
                  round(mean_loss[1], 4), round(mean_loss[2], 4),
                  round(mean_loss[3], 4),
                  self.optimizer.param_groups[0]['lr']))
        step = epoch_index * math.ceil(
            self.len_train_dataset / cfg.batch_size) + self.iter
        self.writer.add_scalar("loss", mean_loss[0], global_step=step)
        self.writer.add_scalar("giou loss", mean_loss[1], global_step=step)
        self.writer.add_scalar("conf loss", mean_loss[2], global_step=step)
        self.writer.add_scalar("cls loss", mean_loss[3], global_step=step)
        self.writer.add_scalar("learning rate",
                               self.optimizer.param_groups[0]['lr'],
                               global_step=step)

    def train_one_epoch(self, epoch_index, train_loss=None, train_lr=None):
        mean_loss = [0, 0, 0, 0]
        self.model.train()
        for self.iter, train_data in enumerate(self.train_dataloader):
            start_time = time.time()
            image, target, _ = train_data
            image = image.to(self.device)
            output, pred = self.model(image)
            # compute the loss
            loss, loss_giou, loss_conf, loss_cls = build_loss(output, target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            end_time = time.time()
            # time spent per iteration
            time_per_iter = end_time - start_time
            loss_items = [
                loss.item(),
                loss_giou.item(),
                loss_conf.item(),
                loss_cls.item()
            ]
            # running mean of the loss components
            mean_loss = [(mean_loss[i] * self.iter + loss_items[i]) /
                         (self.iter + 1) for i in range(4)]
            self.put_log(epoch_index, mean_loss, time_per_iter)

        # record training loss
        loss_value = round(mean_loss[0], 4)
        if isinstance(train_loss, list):
            train_loss.append(loss_value)
        now_lr = self.optimizer.param_groups[0]["lr"]
        if isinstance(train_lr, list):
            train_lr.append(now_lr)

        if (epoch_index + 1) % cfg.save_step == 0:
            checkpoint = {
                'epoch': epoch_index,
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }
            torch.save(
                checkpoint,
                cfg.checkpoint_save_path + cfg.model + '_' +
                str(epoch_index + 1) + '.pth')

    @torch.no_grad()
    def eval(self, epoch_index, mAP_list=None):
        n_threads = torch.get_num_threads()
        # FIXME remove this and make paste_masks_in_image run on the GPU
        torch.set_num_threads(n_threads)
        cpu_device = torch.device("cpu")
        self.model.eval()

        for ann_idx in self.cocoGt.anns:
            ann = self.cocoGt.anns[ann_idx]
            ann['area'] = maskUtils.area(self.cocoGt.annToRLE(ann))

        iou_types = 'segm'
        anns = []
        for val_data in self.val_dataloader:
            image, target, logit = val_data
            image = image.to(self.device)
            # image size after resizing (image.shape[2] == image.shape[3])
            image_size = image.shape[3]
            _, pred = self.model(image)
            # TODO: currently only batch_size=1 is supported
            pred = pred.unsqueeze(0)
            pred = pred[pred[:, :, 8] > cfg.conf_thresh]
            if pred.shape[0] > 0:
                detections = non_max_suppression(pred.unsqueeze(0),
                                                 cls_thres=cfg.cls_thresh,
                                                 nms_thres=cfg.conf_thresh)
                anns.extend(
                    reorginalize_target(detections, logit, image_size,
                                        self.cocoGt))

        for ann in anns:
            # convert polygon-style segmentations to RLE
            ann['segmentation'] = self.cocoGt.annToRLE(ann)
        cocoDt = self.cocoGt.loadRes(anns)
        cocoEval = COCOeval(self.cocoGt, cocoDt, iou_types)
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        ap_per_category(self.cocoGt, cocoEval, epoch_index)
        print_txt = cocoEval.stats
        coco_mAP = print_txt[0]
        voc_mAP = print_txt[1]
        if isinstance(mAP_list, list):
            mAP_list.append(voc_mAP)

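# The eval() method above follows the standard pycocotools pattern. A hedged,
# minimal standalone version (the annotation file path and the content of
# `anns` are illustrative): stats[0] is COCO mAP @ IoU 0.50:0.95 and stats[1]
# the PASCAL-VOC-style mAP @ IoU 0.5 that eval() reports as voc_mAP.
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

cocoGt = COCO('instances_val.json')  # ground-truth annotations
cocoDt = cocoGt.loadRes(anns)        # list of predictions in COCO format
cocoEval = COCOeval(cocoGt, cocoDt, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
coco_mAP, voc_mAP = cocoEval.stats[0], cocoEval.stats[1]
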
    def lr_range_test(self,
                      data_loader,
                      end_lr,
                      num_iter=100,
                      step_mode='exp',
                      alpha=0.05,
                      ax=None):
        # Since the test updates both model and optimizer, we need to store
        # their initial states so they can be restored in the end
        previous_states = {
            'model': deepcopy(self.model.state_dict()),
            'optimizer': deepcopy(self.optimizer.state_dict())
        }
        # Retrieves the learning rate set in the optimizer
        start_lr = self.optimizer.state_dict()['param_groups'][0]['lr']

        # Builds a custom function and corresponding scheduler
        lr_fn = make_lr_fn(start_lr, end_lr, num_iter)
        scheduler = LambdaLR(self.optimizer, lr_lambda=lr_fn)

        # Variables for tracking results and iterations
        tracking = {'loss': [], 'lr': []}
        iteration = 0

        # If there are more iterations than mini-batches in the data loader,
        # it will have to loop over it more than once
        while iteration < num_iter:
            # That's the typical mini-batch inner loop
            for x_batch, y_batch in data_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                # Step 1: forward pass
                yhat = self.model(x_batch)
                # Step 2: compute the loss
                loss = self.loss_fn(yhat, y_batch)
                # Step 3: compute the gradients
                loss.backward()

                # Here we keep track of the (smoothed) losses and the
                # learning rates
                tracking['lr'].append(scheduler.get_last_lr()[0])
                if iteration == 0:
                    tracking['loss'].append(loss.item())
                else:
                    prev_loss = tracking['loss'][-1]
                    smoothed_loss = (alpha * loss.item() +
                                     (1 - alpha) * prev_loss)
                    tracking['loss'].append(smoothed_loss)

                iteration += 1
                # Number of iterations reached
                if iteration == num_iter:
                    break

                # Step 4: update parameters, learning rate, reset gradients
                self.optimizer.step()
                scheduler.step()
                self.optimizer.zero_grad()

        # Restores the original states
        self.optimizer.load_state_dict(previous_states['optimizer'])
        self.model.load_state_dict(previous_states['model'])

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        else:
            fig = ax.get_figure()
        ax.plot(tracking['lr'], tracking['loss'])
        if step_mode == 'exp':
            ax.set_xscale('log')
        ax.set_xlabel('Learning Rate')
        ax.set_ylabel('Loss')
        fig.tight_layout()

        return tracking, fig

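# lr_range_test() above assumes a make_lr_fn() helper, which is not defined
# in this file. A sketch of an exponential version (consistent with
# step_mode='exp'): the multiplier grows from 1 to end_lr / start_lr over
# num_iter iterations, so the effective LR sweeps from start_lr to end_lr.
def make_lr_fn(start_lr, end_lr, num_iter):
    def lr_fn(iteration):
        # exponential interpolation of the LR multiplier
        return (end_lr / start_lr)**(iteration / num_iter)

    return lr_fn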