import os
import pickle

import torch
from ray import tune
from torch.optim.lr_scheduler import LambdaLR

# Project-local modules (assumed paths): `Model` builds the network from a
# config dict, `data_loader.create_loader` wraps numpy arrays in a DataLoader.
from model import Model
import data_loader


class FNN:
    def __init__(self):
        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def set_data(self, features, targets, D, denom_sq):
        self.features_np = features
        self.targets_np = targets
        self.D_np = D
        self.inv_denom_sq = denom_sq**-1

    def train(self, config):
        # Internal (default) config
        self.config = {}
        self.config['num_epochs'] = 5000
        self.config['n_hidden'] = 2
        self.config['hidden_size'] = 40
        self.config['batch_size'] = 10
        self.config['lr'] = 1e-2
        self.config['regularization'] = 1e-10

        # Overwrite internal config values given in the external config
        if config:
            for key in config.keys():
                self.config[key] = config[key]

        # Assume we're using ray.tune at first
        self.tuning = True

        # Model
        self.config['input_size'] = self.features_np['train'].shape[1]
        self.config['output_size'] = self.targets_np['train'].shape[1]
        self.model = Model(self.config).to(self.device)

        # Data loaders
        self.batch_size = self.config['batch_size']
        self.train_loader = data_loader.create_loader(
            self.features_np['train'], self.targets_np['train'],
            self.batch_size, True)
        self.validate_loader = data_loader.create_loader(
            self.features_np['validate'], self.targets_np['validate'],
            self.features_np['validate'].shape[0],  # use all validation samples
            False)  # don't shuffle

        # Hyperparameters
        self.num_epochs = self.config['num_epochs']
        self.learning_rate = self.config['lr']

        # Loss and optimizer
        self.criterion = self.eps_reg_sq
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.learning_rate,
                                          eps=1e-8,
                                          weight_decay=self.config['regularization'])
        lambdaLR = lambda epoch: 1 / (1 + 0.005 * epoch)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambdaLR)

        self.train_start()

    def train_start(self):
        # Train
        early_stop = False
        self.D = torch.from_numpy(self.D_np).float().to(self.device)
        for epoch in range(self.num_epochs):
            for i, (features, targets) in enumerate(self.train_loader):
                self.model.train()
                self.optimizer.zero_grad()

                # Move tensors to the configured device
                features = features.to(self.device)
                targets = targets.to(self.device)

                # Forward pass
                outputs = self.model(features)
                loss = self.criterion(outputs, targets) ** 0.5
                if torch.isnan(loss):
                    print('Something went nan, stopping')
                    early_stop = True
                    break  # break out of the batch loop

                # Backward and optimize
                loss.backward()
                self.optimizer.step()

            if early_stop:
                break  # break out of the epoch loop

            self.scheduler.step()

            if epoch % 10 == 0 or epoch == self.num_epochs - 1:
                validate_loss = self.get_loss(self.validate_loader)
                train_loss = self.get_loss(self.train_loader)
                print('eps_reg: Epoch [{}/{}], LR: {:.2e}, Train loss: {:.2e}, Validate loss: {:.2e}'
                      .format(epoch + 1, self.num_epochs,
                              self.scheduler.get_lr()[0],
                              train_loss.item()**0.5,
                              validate_loss.item()**0.5))
                if self.tuning:
                    try:
                        tune.track.log(mean_loss=validate_loss.item(),
                                       episodes_this_iter=10)
                    except Exception:
                        self.tuning = False
        return self

    def eps_reg_sq(self, outputs, targets):
        return torch.sum((self.D * (targets - outputs)) ** 2) * self.inv_denom_sq / targets.shape[0]

    def get_loss(self, loader):
        with torch.no_grad():
            self.model.eval()
            loss = 0.0
            for features, targets in loader:
                features = features.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(features)
                loss += self.criterion(outputs, targets)
            return loss / len(loader)

    def evaluate(self, features):
        with torch.no_grad():
            self.model.eval()
            # fixed: move the input to the model's device and bring the output
            # back to the CPU before converting to numpy (the original would
            # fail on CUDA)
            output = self.model(torch.tensor(features).float().to(self.device))
            u_rb = output.cpu().numpy()
        return u_rb

    def save(self, model_dir, component):
        try:
            path_config = os.path.join(tune.track.trial_dir(), 'config')
            path_state_dict = os.path.join(tune.track.trial_dir(), 'state_dict')
        except Exception:  # not tuning
            path_config = os.path.join(model_dir, 'FNN', component, 'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component, 'state_dict')
        with open(path_config, 'wb+') as f:
            pickle.dump(self.config, f)
        torch.save(self.model.state_dict(), path_state_dict)

    def load(self, model_dir, component):
        '''Find and load the best model from ray.tune analysis results.'''
        try:
            path_analysis = os.path.join(model_dir, 'FNN', component)
            analysis = tune.Analysis(path_analysis)
            df_temp = analysis.dataframe()
            idx = df_temp['mean_loss'].idxmin()
            logdir = df_temp.loc[idx]['logdir']
            path_config = os.path.join(logdir, 'config')
            path_state_dict = os.path.join(logdir, 'state_dict')
        except Exception:  # no tuning records
            path_config = os.path.join(model_dir, 'FNN', component, 'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component, 'state_dict')
        with open(path_config, 'rb') as f:
            config = pickle.load(f)
        self.model = Model(config).to(self.device)
        state_dict = torch.load(path_state_dict, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)

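# Minimal usage sketch for the FNN wrapper above. Assumptions: the project's
# `Model` and `data_loader` modules are importable, the feature/target dicts
# hold numpy arrays keyed by 'train'/'validate' splits, and D / denom_sq match
# the eps_reg_sq weighting. All names and shapes here are illustrative only.
import numpy as np

features = {'train': np.random.rand(100, 8).astype(np.float32),
            'validate': np.random.rand(20, 8).astype(np.float32)}
targets = {'train': np.random.rand(100, 4).astype(np.float32),
           'validate': np.random.rand(20, 4).astype(np.float32)}
D = np.ones(4, dtype=np.float32)  # error weighting used by eps_reg_sq
denom_sq = 1.0                    # squared normalization constant

fnn = FNN()
fnn.set_data(features, targets, D, denom_sq)
fnn.train({'num_epochs': 100, 'lr': 1e-3})  # overrides the internal defaults
u_rb = fnn.evaluate(features['validate'])
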
def train_domain_classifier(
        model: torch.nn.Module,
        train_dl: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        class_weights: List,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''):
    # best_loss = float('inf')
    best_acc = 0.0
    patience_counter = 0
    epoch_counter = 0
    total = len(train_dl)  # fixed: the original summed over an undefined `train_dls`

    loss_fn = torch.nn.CrossEntropyLoss(
        weight=torch.FloatTensor(class_weights).to(device))

    # Main loop
    while epoch_counter < n_epochs:
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]
            # Testing with random domains to see if any effect
            # domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
            domains = batch[3]

            logits = model(input_ids, attention_mask=masks)[0]
            loss = loss_fn(logits, domains)
            loss = loss / gradient_accumulation
            loss.backward()
            # Only step once every `gradient_accumulation` batches; the
            # original stepped unconditionally, which defeats the loss
            # scaling above and reduces it to a smaller learning rate.
            if (i + 1) % gradient_accumulation == 0:
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()
            gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        # Saving the best model and early stopping
        # if val_loss < best_loss:
        if acc > best_acc:
            best_model = model.state_dict()
            best_acc = acc
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_domainclassifier_{domain_name}.pth'
            )
            patience_counter = 0
        else:
            patience_counter += 1

        # Stop training once we have lost patience
        if patience_counter == patience:
            break

        gc.collect()
        epoch_counter += 1

def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, interp, criterion, optimizer: SGD, lr_scheduler: LambdaLR,
          epoch: int, visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_t = AverageMeter('Loss (t)', ':3.2f')
    losses_entropy_t = AverageMeter('Entropy (t)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')

    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_t, losses_entropy_t,
         accuracies_s, accuracies_t, iou_s, iou_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_cls_t = criterion(pred_t, label_t)
        loss_entropy_t = robust_entropy(y_t, args.ita)
        (args.entropy_weight * loss_entropy_t).backward()

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_t.update(loss_cls_t.item(), x_s.size(0))
        losses_entropy_t.update(loss_entropy_t.item(), x_s.size(0))

        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))

# `no_decay` was undefined in this snippet; the usual BERT-style convention
# (assumed here) exempts biases and LayerNorm weights from weight decay.
no_decay = ['bias', 'LayerNorm.weight']

params = list(model.named_parameters())
optimizer_grouped_parameters = [
    {
        'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    },
    {
        'params': [p for n, p in params if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }
]
optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.lr, bias_correction=False)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))

## DATA
train_loader, val_loader, test_loader = get_data(args)

for epoch in range(args.epochs):
    train(model, optimizer, train_loader, epoch, args)
    if val_loader is not None:
        loss, f1 = evaluate(model, val_loader, args)
        print('val_loss: {:.5f}, val_f1: {:.5f}'.format(loss, f1))
    scheduler.step()

if test_loader is not None:
    loss, f1 = evaluate(model, test_loader, args)
    print('test_loss: {:.5f}, test_f1: {:.5f}'.format(loss, f1))

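# Self-contained sketch of the grouping pattern above on a toy module, showing
# which tensors land in the decayed vs. exempt group. The toy module and the
# use of torch.optim.AdamW (instead of apex FusedAdam) are assumptions made so
# the sketch runs without extra dependencies.
import torch
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)


toy = Toy()
no_decay = ['bias', 'LayerNorm.weight']
decayed = [n for n, _ in toy.named_parameters()
           if not any(nd in n for nd in no_decay)]
exempt = [n for n, _ in toy.named_parameters()
          if any(nd in n for nd in no_decay)]
print(decayed)  # ['linear.weight']
print(exempt)   # ['linear.bias', 'LayerNorm.weight', 'LayerNorm.bias']

grouped = [
    {'params': [p for n, p in toy.named_parameters() if n in decayed],
     'weight_decay': 0.01},
    {'params': [p for n, p in toy.named_parameters() if n in exempt],
     'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped, lr=1e-3)
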
def train(self) -> None:
    r"""Main method for DD-PPO.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += (self.world_rank * self.config.NUM_PROCESSES)
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    self.envs = construct_envs(
        self.config,
        get_env_class(self.config.ENV_NAME),
        workers_ignore_signals=True,
    )

    ppo_cfg = self.config.RL.PPO
    if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)

    obs_space = self.envs.observation_spaces[0]
    if self._static_encoder:
        self._encoder = self.actor_critic.net.visual_encoder
        obs_space = SpaceDict({
            "visual_features": spaces.Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=self._encoder.output_shape,
                dtype=np.float32,
            ),
            **obs_space.spaces,
        })
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
    )
    rollouts.to(self.device)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if self.world_rank == 0 else contextlib.suppress()) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ))

                requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(rollouts, current_episode_reward,
                                               running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens. If a
                # worker detects it will be a straggler, it preempts itself!
                if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                    break

            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self._static_encoder:
                self._encoder.eval()

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0)
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            stats = torch.tensor(
                [value_loss, action_loss, count_steps_delta],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[2].item()

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                ]
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "reward",
                    deltas["reward"] / deltas["count"],
                    count_steps,
                )

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        count_steps / ((time.time() - t_start) + prev_time),
                    ))
                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))
                    logger.info("Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                 for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    count_checkpoints += 1

        self.envs.close()

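# The lr_lambda above multiplies the base learning rate by
# linear_decay(update, NUM_UPDATES). `linear_decay` is defined elsewhere in
# the project; a minimal sketch consistent with that usage (an assumption,
# not the project's verbatim code) is the obvious linear ramp to zero:
def linear_decay(epoch: int, total_num_updates: int) -> float:
    """Multiplier that anneals linearly from 1.0 down to 0.0."""
    return 1.0 - (epoch / float(total_num_updates))
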
def main(args):
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size,
                            ratio=args.resize_ratio,
                            scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()),
                       lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()),
                       lr=args.lr, betas=(args.beta1, 0.999))
    # constant lr for the first `epochs` epochs, then linear decay over `epochs_decay`
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'test':
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss functions
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToPILImage()
    ])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saved image
        """
        tensor_to_image(image).save(
            logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S,
              netD_S, netD_T, criterion_gan, criterion_cycle,
              criterion_identity, optimizer_G, optimizer_D, fake_S_pool,
              fake_T_pool, epoch, visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))

    if args.translated_root is not None:
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()

def train(self) -> None:
    r"""Main method for training DD/PPO.

    Returns:
        None
    """
    self._init_train()

    count_checkpoints = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: 1 - self.percent_done(),
    )

    resume_state = load_resume_state(self.config)
    if resume_state is not None:
        self.agent.load_state_dict(resume_state["state_dict"])
        self.agent.optimizer.load_state_dict(resume_state["optim_state"])
        lr_scheduler.load_state_dict(resume_state["lr_sched_state"])

        requeue_stats = resume_state["requeue_stats"]
        self.env_time = requeue_stats["env_time"]
        self.pth_time = requeue_stats["pth_time"]
        self.num_steps_done = requeue_stats["num_steps_done"]
        self.num_updates_done = requeue_stats["num_updates_done"]
        self._last_checkpoint_percent = requeue_stats["_last_checkpoint_percent"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        prev_time = requeue_stats["prev_time"]

        self.running_episode_stats = requeue_stats["running_episode_stats"]
        self.window_episode_stats.update(requeue_stats["window_episode_stats"])

    ppo_cfg = self.config.RL.PPO

    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if rank0_only() else contextlib.suppress()) as writer:
        while not self.is_done():
            profiling_wrapper.on_start_step()
            profiling_wrapper.range_push("train update")

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * (
                    1 - self.percent_done())

            if rank0_only() and self._should_save_resume_state():
                requeue_stats = dict(
                    env_time=self.env_time,
                    pth_time=self.pth_time,
                    count_checkpoints=count_checkpoints,
                    num_steps_done=self.num_steps_done,
                    num_updates_done=self.num_updates_done,
                    _last_checkpoint_percent=self._last_checkpoint_percent,
                    prev_time=(time.time() - self.t_start) + prev_time,
                    running_episode_stats=self.running_episode_stats,
                    window_episode_stats=dict(self.window_episode_stats),
                )
                save_resume_state(
                    dict(
                        state_dict=self.agent.state_dict(),
                        optim_state=self.agent.optimizer.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    ),
                    self.config,
                )

            if EXIT.is_set():
                profiling_wrapper.range_pop()  # train update
                self.envs.close()
                requeue_job()
                return

            self.agent.eval()
            count_steps_delta = 0
            profiling_wrapper.range_push("rollouts loop")

            profiling_wrapper.range_push("_collect_rollout_step")
            for buffer_index in range(self._nbuffers):
                self._compute_actions_and_step_envs(buffer_index)

            for step in range(ppo_cfg.num_steps):
                is_last_step = (self.should_end_early(step + 1)
                                or (step + 1) == ppo_cfg.num_steps)

                for buffer_index in range(self._nbuffers):
                    count_steps_delta += self._collect_environment_result(buffer_index)

                    if (buffer_index + 1) == self._nbuffers:
                        profiling_wrapper.range_pop()  # _collect_rollout_step

                    if not is_last_step:
                        if (buffer_index + 1) == self._nbuffers:
                            profiling_wrapper.range_push("_collect_rollout_step")
                        self._compute_actions_and_step_envs(buffer_index)

                if is_last_step:
                    break

            profiling_wrapper.range_pop()  # rollouts loop

            if self._is_distributed:
                self.num_rollouts_done_store.add("num_done", 1)

            (
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent()

            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()  # type: ignore

            self.num_updates_done += 1
            losses = self._coalesce_post_step(
                dict(value_loss=value_loss, action_loss=action_loss),
                count_steps_delta,
            )

            self._training_log(writer, losses, prev_time)

            # checkpoint model
            if rank0_only() and self.should_checkpoint():
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth",
                    dict(
                        step=self.num_steps_done,
                        wall_time=(time.time() - self.t_start) + prev_time,
                    ),
                )
                count_checkpoints += 1

            profiling_wrapper.range_pop()  # train update

        self.envs.close()

def train(self) -> None:
    r"""Main method for DD-PPO SLAM.

    Returns:
        None
    """

    #####################################################################
    ## init distrib and configuration
    #####################################################################
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()  # server number
    self.world_size = distrib.get_world_size()

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank  # gpu number in one server
    self.config.SIMULATOR_GPU_ID = self.local_rank
    print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID)
    print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID)

    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += (self.world_rank * self.config.NUM_PROCESSES)
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    #####################################################################
    ## build distrib NavSLAMRLEnv environment
    #####################################################################
    print("#############################################################")
    print("## build distrib NavSLAMRLEnv environment")
    print("#############################################################")
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))

    observations = self.envs.reset()
    print("*************************** observations len:", len(observations))

    # semantic process
    for i in range(len(observations)):
        observations[i]["semantic"] = observations[i]["semantic"].astype(np.int32)
        se = list(set(observations[i]["semantic"].ravel()))
        print(se)

    batch = batch_obs(observations, device=self.device)
    print("*************************** batch len:", len(batch))

    #####################################################################
    ## init actor_critic agent
    #####################################################################
    print("#############################################################")
    print("## init actor_critic agent")
    print("#############################################################")
    self.map_w = observations[0]["map_sum"].shape[0]
    self.map_h = observations[0]["map_sum"].shape[1]

    ppo_cfg = self.config.RL.PPO
    if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(observations, ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))

    #####################################################################
    ## init Global Rollout Storage
    #####################################################################
    print("#############################################################")
    print("## init Global Rollout Storage")
    print("#############################################################")
    self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step
    rollouts = GlobalRolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.obs_space,
        self.g_action_space,
    )
    rollouts.to(self.device)
    print('rollouts type:', type(rollouts))
    print('--------------------------')

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    with torch.no_grad():
        step_observation = {
            k: v[rollouts.step] for k, v in rollouts.observations.items()
        }
        _, actions, _, = self.actor_critic.act(
            step_observation,
            rollouts.prev_g_actions[0],
            rollouts.masks[0],
        )

    self.global_goals = [
        [int(action[0].item() * self.map_w),
         int(action[1].item() * self.map_h)]
        for action in actions
    ]

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))
    print("*************************** current_episode_reward:", current_episode_reward)
    print("*************************** running_episode_stats:", running_episode_stats)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    deif = {}  # maps scene_id -> list of visited episode ids
    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if self.world_rank == 0 else contextlib.suppress()) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ),
                        "/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth")
                print("********************EXIT*********************")

                requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_global_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # record which episodes of which scenes have been visited
                for i in range(len(self.envs.current_episodes())):
                    episode = self.envs.current_episodes()[i]
                    if episode.scene_id not in deif:
                        deif[episode.scene_id] = [int(episode.episode_id)]
                    else:
                        deif[episode.scene_id].append(int(episode.episode_id))

                # This is where the preemption of workers happens. If a
                # worker detects it will be a straggler, it preempts itself!
                if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                    break

            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self._static_encoder:
                self._encoder.eval()

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0)
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            stats = torch.tensor(
                [value_loss, action_loss, count_steps_delta],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[2].item()

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                ]
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "reward",
                    deltas["reward"] / deltas["count"],
                    count_steps,
                )

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        count_steps / ((time.time() - t_start) + prev_time),
                    ))
                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time, count_steps))
                    logger.info("Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                 for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    print('=' * 20 + 'Save Model' + '=' * 20)
                    logger.info("Save Model : {}".format(count_checkpoints))
                    count_checkpoints += 1

    self.envs.close()

@dataclass
class PPOAgent(BaseAgent):
    # Note: the @dataclass decorator is assumed to have been lost in
    # extraction; __post_init__ is only invoked for dataclasses, so the
    # class cannot work without it.
    actor: nn.Module
    critic: nn.Module
    same_body: float = False

    def __post_init__(self):
        move_to([self.actor, self.critic], device=cfg.alg.device)
        if cfg.alg.vf_loss_type == 'mse':
            self.val_loss_criterion = nn.MSELoss().to(cfg.alg.device)
        elif cfg.alg.vf_loss_type == 'smoothl1':
            self.val_loss_criterion = nn.SmoothL1Loss().to(cfg.alg.device)
        else:
            raise TypeError(f'Unknown value loss type: {cfg.alg.vf_loss_type}!')
        all_params = list(self.actor.parameters()) + list(self.critic.parameters())
        # keep unique elements only. The following code works for python >=3.7;
        # for earlier versions of python, you need to use OrderedDict
        self.all_params = dict.fromkeys(all_params).keys()
        if (cfg.alg.linear_decay_lr or cfg.alg.linear_decay_clip_range) and \
                cfg.alg.max_steps > cfg.alg.max_decay_steps:
            logger.warning('max_steps should not be greater than max_decay_steps.')
            cfg.alg.max_decay_steps = int(cfg.alg.max_steps * 1.5)
            logger.warning(f'Resetting max_decay_steps to {cfg.alg.max_decay_steps}!')
        total_epochs = int(
            np.ceil(cfg.alg.max_decay_steps /
                    (cfg.alg.num_envs * cfg.alg.episode_steps)))
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = cfg.alg.clip_range / float(total_epochs)

        p_lr_lambda = partial(linear_decay_percent, total_epochs=total_epochs)
        optim_args = dict(lr=cfg.alg.policy_lr, weight_decay=cfg.alg.weight_decay)
        if not cfg.alg.sgd:
            optim_args['amsgrad'] = cfg.alg.use_amsgrad
            optim_func = optim.Adam
        else:
            optim_args['nesterov'] = True if cfg.alg.momentum > 0 else False
            optim_args['momentum'] = cfg.alg.momentum
            optim_func = optim.SGD
        if self.same_body:
            optim_args['params'] = self.all_params
        else:
            optim_args['params'] = [
                {'params': self.actor.parameters(), 'lr': cfg.alg.policy_lr},
                {'params': self.critic.parameters(), 'lr': cfg.alg.value_lr},
            ]

        self.optimizer = optim_func(**optim_args)

        if self.same_body:
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda])
        else:
            v_lr_lambda = partial(linear_decay_percent, total_epochs=total_epochs)
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda, v_lr_lambda])

    @torch.no_grad()
    def get_action(self, ob, sample=True, *args, **kwargs):
        self.eval_mode()
        if type(ob) is dict:
            t_ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
        else:
            t_ob = torch_float(ob, device=cfg.alg.device)
        act_dist, val = self.get_act_val(t_ob)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        return torch_to_np(action), action_info

    def get_act_val(self, ob, *args, **kwargs):
        if type(ob) is dict:
            ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
        else:
            ob = torch_float(ob, device=cfg.alg.device)
        act_dist, body_out = self.actor(ob)
        if self.same_body:
            val, body_out = self.critic(body_x=body_out)
        else:
            val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return act_dist, val

    @torch.no_grad()
    def get_val(self, ob, *args, **kwargs):
        self.eval_mode()
        if type(ob) is dict:
            ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
        else:
            ob = torch_float(ob, device=cfg.alg.device)
        val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return val

    def optimize(self, data, *args, **kwargs):
        pre_res = self.optim_preprocess(data)
        processed_data = pre_res
        processed_data['entropy'] = torch.mean(processed_data['entropy'])
        loss_res = self.cal_loss(**processed_data)
        loss, pg_loss, vf_loss, ratio = loss_res
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = clip_grad(self.all_params, cfg.alg.max_grad_norm)
        self.optimizer.step()
        with torch.no_grad():
            approx_kl = 0.5 * torch.mean(
                torch.pow(processed_data['old_log_prob'] - processed_data['log_prob'], 2))
            clip_frac = np.mean(
                np.abs(torch_to_np(ratio) - 1.0) > cfg.alg.clip_range)
        optim_info = dict(pg_loss=pg_loss.item(),
                          vf_loss=vf_loss.item(),
                          total_loss=loss.item(),
                          entropy=processed_data['entropy'].item(),
                          approx_kl=approx_kl.item(),
                          clip_frac=clip_frac)
        optim_info['grad_norm'] = grad_norm
        return optim_info

    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']

        act_dist, val = self.get_act_val({"ob": ob, "state": state})
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data

    def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv, entropy):
        vf_loss = self.cal_val_loss(val=val, old_val=old_val, ret=ret)
        ratio = torch.exp(log_prob - old_log_prob)
        surr1 = adv * ratio
        surr2 = adv * torch.clamp(ratio,
                                  1 - cfg.alg.clip_range,
                                  1 + cfg.alg.clip_range)
        pg_loss = -torch.mean(torch.min(surr1, surr2))
        loss = pg_loss - entropy * cfg.alg.ent_coef + vf_loss * cfg.alg.vf_coef
        return loss, pg_loss, vf_loss, ratio

    def cal_val_loss(self, val, old_val, ret):
        if cfg.alg.clip_vf_loss:
            clipped_val = old_val + torch.clamp(
                val - old_val, -cfg.alg.clip_range, cfg.alg.clip_range)
            vf_loss1 = torch.pow(val - ret, 2)
            vf_loss2 = torch.pow(clipped_val - ret, 2)
            vf_loss = 0.5 * torch.mean(torch.max(vf_loss1, vf_loss2))
        else:
            vf_loss = 0.5 * self.val_loss_criterion(val, ret)
        return vf_loss

    def train_mode(self):
        self.actor.train()
        self.critic.train()

    def eval_mode(self):
        self.actor.eval()
        self.critic.eval()

    def decay_lr(self):
        self.lr_scheduler.step()

    def get_lr(self):
        cur_lr = self.lr_scheduler.get_lr()
        lrs = {'policy_lr': cur_lr[0]}
        if len(cur_lr) > 1:
            lrs['value_lr'] = cur_lr[1]
        return lrs

    def decay_clip_range(self):
        cfg.alg.clip_range -= self.clip_range_decay_rate

    def save_model(self, is_best=False, step=None):
        self.save_env(cfg.alg.model_dir)
        data_to_save = {
            'step': step,
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
            'lr_scheduler_state_dict': self.lr_scheduler.state_dict()
        }
        if cfg.alg.linear_decay_clip_range:
            data_to_save['clip_range'] = cfg.alg.clip_range
            data_to_save['clip_range_decay_rate'] = self.clip_range_decay_rate
        save_model(data_to_save, cfg.alg, is_best=is_best, step=step)

    def load_model(self, step=None, pretrain_model=None):
        self.load_env(cfg.alg.model_dir)
        ckpt_data = load_ckpt_data(cfg.alg, step=step, pretrain_model=pretrain_model)
        load_state_dict(self.actor, ckpt_data['actor_state_dict'])
        load_state_dict(self.critic, ckpt_data['critic_state_dict'])
        if pretrain_model is not None:
            return
        self.optimizer.load_state_dict(ckpt_data['optim_state_dict'])
        self.lr_scheduler.load_state_dict(ckpt_data['lr_scheduler_state_dict'])
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = ckpt_data['clip_range_decay_rate']
            cfg.alg.clip_range = ckpt_data['clip_range']
        return ckpt_data['step']

    def print_param_grad_status(self):
        logger.info('Requires Grad?')
        logger.info('================== Actor ================== ')
        for name, param in self.actor.named_parameters():
            print(f'{name}: {param.requires_grad}')
        logger.info('================== Critic ================== ')
        for name, param in self.critic.named_parameters():
            print(f'{name}: {param.requires_grad}')

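# PPOAgent builds its LambdaLR schedules from
# partial(linear_decay_percent, total_epochs=total_epochs). The function is
# defined elsewhere in the project; a minimal sketch consistent with that
# call signature (an assumption, not the project's verbatim code):
def linear_decay_percent(epoch: int, total_epochs: int) -> float:
    """Multiplier that ramps linearly from 1.0 at epoch 0 to 0.0 at total_epochs."""
    return 1.0 - epoch / float(total_epochs)
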
class Trainer(object):
    def __init__(self, args):
        self.name = args.name
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.log = args.log
        self.out_every = args.out_every
        # fixed: the original line ended with a stray comma, which silently
        # turned pos_w into a one-element tuple
        self.pos_w = args.pos_w
        self.LAMBDA = args.LAMBDA

        if args.cuda_dev:
            torch.cuda.set_device(args.cuda_dev[0])
            self.cuda_dev = f'cuda:{args.cuda_dev[0]}'
            self.device = 'cuda'
        else:
            self.cuda_dev = None
            self.device = 'cpu'
        print(f'Using {self.device}')

        self.z_dim = args.z_dim
        self.batch_size = args.batch_size
        self.start_save = args.start_save
        self.start_epoch = args.start_epoch
        self.ckpt_dir = os.path.join(args.ckpt_dir, self.name)

        if args.data_type == 'simATAC':
            self.dataset = SimATAC(args.setting, args.signal, args.frags,
                                   args.bin_size, conv=args.conv)
        elif args.data_type == 'atlas':
            self.dataset = MouseAtlas(cutoff=CUT_OFF)
        elif args.data_type == 'pbmc':
            self.dataset = PBMC()
        elif args.data_type == 'mergeSim':
            if args.num:
                self.dataset = MergeSim(args.num)
            else:
                self.dataset = MergeSim()
        else:
            raise Exception(f'Dataset {args.data_type} does not exist!')

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=3 * len(args.cuda_dev),
                                     pin_memory=True,
                                     drop_last=True)
        input_dim = self.dataset.padto

        if args.model_type == 'inv':
            if args.sample_batch:
                self.de_batch = True
                self.vae = VAE2(input_dim, args.z_dim, batch=True)
            else:
                self.de_batch = False
                self.vae = VAE2(input_dim, args.z_dim)
            self.vaeI = VAEInv(self.vae)
            self.model = nn.DataParallel(self.vaeI, device_ids=args.cuda_dev)
        else:
            raise Exception(f'Model type {args.model_type} does not exist!')
        self.model_type = args.model_type

        if args.load_ckpt:
            self.load_ckpt(args.load_ckpt)

        self.model.to(self.device)
        self.optim = optim.Adam(self.model.parameters(),
                                lr=self.lr,
                                weight_decay=self.weight_decay)
        self.cycle = CYCLE * self.dataset.__len__() // self.batch_size // len(args.cuda_dev)
        lr_lmd = lambda epoch: 0.995**epoch
        self.le_scdlr = LambdaLR(self.optim, lr_lambda=lr_lmd)
        self.le_scdlr.last_epoch = self.start_epoch - 1

    def load_ckpt(self, ckpt_pth):
        if os.path.isfile(ckpt_pth):
            print('Loading ' + ckpt_pth)
            if self.cuda_dev:
                self.model.module.load_state_dict(
                    torch.load(ckpt_pth, map_location=self.cuda_dev))
            else:
                self.model.module.load_state_dict(
                    torch.load(ckpt_pth, map_location='cpu'))
            print('Finished Loading ckpt...')
        else:
            raise Exception(ckpt_pth + "\nckpt does not exist!")

    def warm_up(self):
        if not os.path.exists(self.ckpt_dir):
            print(f'Making dir {self.ckpt_dir}')
            os.makedirs(self.ckpt_dir)
        self.model.train()
        self.pbar = tqdm(total=WARM_UP)
        total_iter = 0
        for step in range(WARM_UP):
            for x, s, l in self.dataloader:
                l = l.unsqueeze(1).float().to(self.device).log()
                l = (l - self.dataset.d_mean) / self.dataset.d_std
                total_iter += 1
                x = x.float().to(self.device)
                if self.model_type == 'adv':
                    _, _, _, rec, _ = self.model(x, l)
                elif self.model_type == 'inv':
                    if self.de_batch:
                        s = s.unsqueeze(1).float().to(self.device)
                        _, _, _, rec = self.model(x, l, b=s)
                    else:
                        _, _, _, rec = self.model(x, l)
                else:
                    _, _, _, rec = self.model(x)
                pos_weight = torch.Tensor([self.pos_w]).to(self.device)
                bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                rec_loss = bce(rec, x)
                self.optim.zero_grad()
                rec_loss.backward()
                self.optim.step()
            self.pbar.update(1)
        torch.save(self.model.module.state_dict(),
                   os.path.join(self.ckpt_dir, 'warmup.pt'))
        self.pbar.write("[Warmup Finished]")
        self.pbar.close()

    def rec_all(self, batch_size=1, same_depth=False):
        dataloader = DataLoader(self.dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=False)
        labels = []
        self.model.eval()
        for i, dp in tqdm(enumerate(dataloader)):
            x, l, d = dp
            x = x.float().to(self.device)
            labels = labels + l
            if same_depth:
                d = d.unsqueeze(1).float().to(self.device).log()
                d = (torch.ones_like(d) * same_depth).log()
            else:
                d = d.unsqueeze(1).float().to(self.device).log()
            with torch.no_grad():
                _, _, _, rec = self.model.forward(x, d)
                rec = rec.cpu()
            if i == 0:
                out = rec
            else:
                out = torch.cat((out, rec))
        return out, labels

    def rec_batch(self, batch_size=1, same_depth=False):
        dataloader = DataLoader(self.dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=False)
        labels = []
        self.model.eval()
        for i, dp in tqdm(enumerate(dataloader)):
            x, l, d = dp
            x = x.float().to(self.device)
            labels = labels + l
            if same_depth:
                d = d.unsqueeze(1).float().to(self.device).log()
                d = (torch.ones_like(d) * same_depth).log()
            else:
                d = d.unsqueeze(1).float().to(self.device).log()
            b = torch.zeros_like(d).float().to(self.device)
            with torch.no_grad():
                _, _, _, rec = self.model.forward(x, d, b)
                rec = rec.cpu()
            if i == 0:
                out = rec
            else:
                out = torch.cat((out, rec))
        return out, labels

    def inv_train(self):
        if not os.path.exists(self.ckpt_dir):
            print(f'Making dir {self.ckpt_dir}')
            os.makedirs(self.ckpt_dir)
        self.model.train()
        kl_list, rec_list = [], []
        print('Inv Training started')
        self.pbar = tqdm(total=self.max_epoch - self.start_epoch)
        total_iter = (self.start_epoch - 1) * self.dataset.__len__() // self.batch_size + 1
        for epoch in range(self.start_epoch, self.start_epoch + self.max_epoch):
            epoch_kl, epoch_rec = [], []
            # cyclical KL-annealing weight in [0, 1]
            kl_w = np.min([
                2 * (total_iter - (total_iter // self.cycle) * self.cycle) / self.cycle,
                1
            ])
            for x1, s1, l1 in self.dataloader:
                x1 = x1.float().to(self.device)
                l1 = l1.log()
                l1 = (l1 - self.dataset.d_mean) / self.dataset.d_std
                l1 = l1.unsqueeze(1).float().to(self.device)
                if self.de_batch:
                    s1 = s1.unsqueeze(1).float().to(self.device)
                    z_mean, z_log_var, _, rec = self.model(x1, l1, b=s1)
                else:
                    z_mean, z_log_var, _, rec = self.model(x1, l1)
                mean = torch.zeros_like(z_mean)
                var = torch.ones_like(z_log_var)
                kld_z = kl(Normal(z_mean, torch.exp(z_log_var).sqrt()),
                           Normal(mean, var)).sum()
                pos_weight = torch.Tensor([self.pos_w]).to(self.device)
                # note: `weight=` applies a per-element rescaling; if the
                # intent was to up-weight positive targets only, `pos_weight=`
                # would be the argument to use
                bce = F.binary_cross_entropy_with_logits(rec, x1,
                                                         weight=pos_weight,
                                                         reduction='sum')
                rec_loss = bce
                m_kld = apprx_kl(z_mean, torch.exp(z_log_var).sqrt()).mean() - 0.5 * self.z_dim
                loss = kld_z * kl_w + (1 + self.LAMBDA) * rec_loss + m_kld * kl_w * self.LAMBDA
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                epoch_kl.append(kld_z.item())
                epoch_rec.append(bce.item())
                total_iter += 1
            kl_list.append(np.mean(epoch_kl))
            rec_list.append(np.mean(epoch_rec))
            self.pbar.update(1)
            self.le_scdlr.step()
            if epoch % self.out_every == 0:
                if epoch > self.start_save:
                    torch.save(self.model.module.state_dict(),
                               os.path.join(self.ckpt_dir, f'{epoch}.pt'))
                logdata = {
                    'iter': list(range(self.start_epoch, epoch + 1)),
                    'kl': kl_list,
                    'bce': rec_list
                }
                df = pd.DataFrame(logdata)
                df.to_csv(os.path.join(self.ckpt_dir, 'inv' + self.log), index=False)
        self.pbar.write("[Inv Training Finished]")
        self.pbar.close()

    def encode_adv(self, batch_size=1000):
        dataloader = DataLoader(self.dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=False)
        labels = []
        latent = torch.zeros(self.dataset.__len__(), self.z_dim)
        depth = torch.zeros(self.dataset.__len__())
        self.model.eval()
        for i, dp in tqdm(enumerate(dataloader)):
            x, l, d = dp
            x = x.float().to(self.device)
            if self.de_batch:
                labels = labels + list(l)
            else:
                labels = labels + l
            depth[i * batch_size:(i + 1) * batch_size] = d
            d = d.log()
            d = (d - self.dataset.d_mean) / self.dataset.d_std
            d = d.unsqueeze(1).float().to(self.device)
            with torch.no_grad():
                z_mean, _ = self.model.forward(x, d, no_rec=True)
            latent[i * batch_size:(i + 1) * batch_size] = z_mean.cpu()
        return latent, labels, depth

def run_COPT(game, num_iter=5000, lr=0.5, seed=1234, biased=False,
             shuffling=False, lr_schedule=None, hamiltonian_coeff=10, **kwargs):
    config = Config(dict(mode="consensus_opt", num_iter=num_iter, lr=lr,
                         seed=seed, hamiltonian_coeff=hamiltonian_coeff,
                         shuffling=shuffling))

    torch.manual_seed(seed)
    game.reset()

    sgd = optim.SGD(game.parameters(), lr=lr)
    if lr_schedule is not None:
        lr_schedule = SchedulerLR(lr_schedule)
        scheduler = LambdaLR(sgd, lr_schedule)
    else:
        scheduler = LambdaLR(sgd, lambda k: 1.)

    logger = defaultdict(list)

    path = None
    if kwargs["output"] is not None:
        path = os.path.join(kwargs["output"], config.name, str(seed))
        config["path"] = path
        if not os.path.exists(path):
            os.makedirs(os.path.join(path, "results"))
        config["name"] = config.name
        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump(config, f, default=lambda x: "non-serializable")

    if shuffling:
        game.shuffle()

    n_samples = 0
    start_time = time.time()
    for i in tqdm(range(num_iter)):
        index1 = game.sample()
        index2 = game.sample()
        if biased is True:
            grad1 = game.compute_grad(index1)
            grad2 = grad1
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 1
        elif biased == "copt":
            grad1 = game.compute_grad(torch.cat([index1, index2]))
            grad2 = grad1
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 2
        elif biased is False:
            grad1 = game.compute_grad(index1)
            grad2 = game.compute_grad(index2)
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 2
        else:
            raise ValueError()

        grad_H = autograd.grad(hamiltonian, game.parameters())
        for p, g1, g2, gH in zip(game.parameters(), grad1, grad2, grad_H):
            p.grad = 0.5 * (g1 + g2) + hamiltonian_coeff * gH

        sgd.step()
        scheduler.step()

        metrics = game.compute_metrics()
        for key, value in metrics.items():
            logger[key].append(value)
        logger["lr"].append(scheduler.get_last_lr())
        logger["num_samples"].append(n_samples)
        logger["time"].append(time.time() - start_time)

        # guard added: `path` is only defined when an output directory is given
        if path is not None and i % 10000 == 0:
            with open(os.path.join(path, "results.json"), "w") as f:
                json.dump(logger, f)

    return logger, config

def run_SHGD(game, num_iter=5000, lr=None, seed=1234, save_params=False,
             biased=False, shuffling=False, lr_schedule=None, **kwargs):
    if lr is None:
        lr = float(1 / (2 * game.L))
    if lr_schedule == "optimal":
        lr_schedule = int(4 * (game.L / game.mu))

    config = Config(dict(mode="shgd", num_iter=num_iter, lr=lr, seed=seed,
                         biased=biased, lr_schedule=lr_schedule,
                         shuffling=shuffling))

    torch.manual_seed(seed)
    game.reset()

    sgd = optim.SGD(game.parameters(), lr=lr)
    if lr_schedule is not None:
        lr_schedule = SchedulerLR(lr_schedule)
        scheduler = LambdaLR(sgd, lr_schedule)
    else:
        scheduler = LambdaLR(sgd, lambda k: 1.)

    logger = defaultdict(list)

    path = None
    if kwargs["output"] is not None:
        path = os.path.join(kwargs["output"], config.name, str(seed))
        config["path"] = path
        if not os.path.exists(path):
            os.makedirs(os.path.join(path, "results"))
        config["name"] = config.name
        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump(config, f, default=lambda x: "non-serializable")

    if shuffling:
        game.shuffle()

    n_samples = 0
    params_history = []
    start_time = time.time()
    for i in tqdm(range(num_iter)):
        index1 = game.sample()
        index2 = game.sample()
        if biased is True:
            hamiltonian = game.compute_hamiltonian(index1)
            n_samples += 1
        elif biased == "copt":
            hamiltonian = game.compute_hamiltonian(torch.cat([index1, index2]))
            n_samples += 2
        elif biased is False:
            hamiltonian = game.compute_hamiltonian(index1, index2)
            n_samples += 2
        else:
            raise ValueError()

        grad = autograd.grad(hamiltonian, game.parameters())
        for p, g in zip(game.parameters(), grad):
            p.grad = g

        sgd.step()
        scheduler.step()

        metrics = game.compute_metrics()
        for key, value in metrics.items():
            logger[key].append(value)
        # logger["lr"].append(scheduler.get_last_lr())
        logger["num_samples"].append(n_samples)
        logger["time"].append(time.time() - start_time)

        if save_params:
            params_history.append(copy.deepcopy(game.state_dict()))

        # guard added: `path` is only defined when an output directory is given
        if path is not None and i % 10000 == 0:
            with open(os.path.join(path, "results.json"), "w") as f:
                json.dump(logger, f)

    logger["params"] = params_history
    return logger, config

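# Both runners wrap an integer `lr_schedule` in SchedulerLR and hand the
# instance to LambdaLR as lr_lambda, so SchedulerLR must be a callable mapping
# the step index k to a multiplier. Its real definition lives elsewhere in
# this repo; one plausible shape consistent with the "optimal" threshold
# int(4 * L / mu) (purely an assumption) is a constant phase followed by
# O(1/k) decay:
class SchedulerLR:
    def __init__(self, threshold: int):
        self.threshold = threshold

    def __call__(self, k: int) -> float:
        # keep the base lr for the first `threshold` steps, then decay ~ 1/k
        if k < self.threshold:
            return 1.0
        return self.threshold / float(k + 1)
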
x_train, x_test, y_train, y_test = get_data()
train_dataloader = inf_data_gen(x_train, y_train, cfg.TRAIN.BATCH_SIZE)
X_test = torch.Tensor(x_test).to(cfg.SYSTEM.DEVICE)
Y_test = torch.Tensor(y_test).to(cfg.SYSTEM.DEVICE)

net = Net(D=cfg.MODEL.D, W=cfg.MODEL.W)
net.to(cfg.SYSTEM.DEVICE)
optimizer = torch.optim.SGD(net.parameters(), lr=cfg.TRAIN.LEARNING_RATE)
scheduler = LambdaLR(optimizer, lr_lambda=inv_root_lr)

pbar = tqdm(train_dataloader, total=cfg.TRAIN.STEPS)
for n_iter, (X, T) in enumerate(pbar, start=1):
    X, T = X.to(cfg.SYSTEM.DEVICE), T.to(cfg.SYSTEM.DEVICE)
    optimizer.zero_grad()
    net.train()
    train(net, X, T, optimizer, n_iter)
    if n_iter % 5000 == 0:
        net.eval()
        test(net, X_test, Y_test, n_iter)
        _WRITER.add_scalar('LR', get_lr(optimizer), n_iter)
    scheduler.step(n_iter)
    if n_iter % (cfg.TRAIN.STEPS / 10) == 0:
        analysis(net, x_train, y_train, n_iter)
    if n_iter > cfg.TRAIN.STEPS:
        break

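# `inv_root_lr` is referenced above but not defined in this snippet. A common
# choice matching the name (an assumption, not the project's verbatim code)
# is an inverse square-root decay of the lr multiplier:
import math


def inv_root_lr(step: int) -> float:
    """Multiplier ~ 1/sqrt(step), clamped so step 0 does not divide by zero."""
    return 1.0 / math.sqrt(max(1, step))
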
def main(): global opt # train data loader train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batchSize, shuffle=True, num_workers=int(opt.workers)) # create model model = models.VAMetric_conv() if opt.init_model != '': print('loading pretrained model from {0}'.format(opt.init_model)) model.load_state_dict(torch.load(opt.init_model)) # Contrastive Loss criterion = models.conv_loss_dqy() if opt.cuda: print('shift model and criterion to GPU .. ') model = model.cuda() criterion = criterion.cuda() # optimizer # optimizer = optim.SGD(model.parameters(), lr=opt.lr, # momentum=opt.momentum, # weight_decay=opt.weight_decay) optimizer = optim.Adam(model.parameters(), lr=opt.lr) # optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, momentum=opt.momentum) # optimizer = optim.Adadelta(params=model.parameters(), lr=opt.lr) # adjust learning rate every lr_decay_epoch lambda_lr = lambda epoch: opt.lr_decay**( (epoch + 1) // opt.lr_decay_epoch) # poly policy scheduler = LR_Policy(optimizer, lambda_lr) resume_epoch = 0 global dis1_rec global dis2_rec global loss_rec loss_rec = [] dis1_rec = [] dis2_rec = [] ######### to test each epoch parser = OptionParser() parser.add_option('--config', type=str, help="evaluation configuration", default="./configs/test_config.yaml") (opts_test, args) = parser.parse_args() opts_test = Config(opts_test.config) test_video_dataset = VideoFeatDataset(opts_test.data_dir, opts_test.video_flist, which_feat='vfeat') test_audio_dataset = VideoFeatDataset(opts_test.data_dir, opts_test.audio_flist, which_feat='afeat') test_video_loader = torch.utils.data.DataLoader( test_video_dataset, batch_size=opts_test.batchSize, shuffle=False, num_workers=int(opts_test.workers)) test_audio_loader = torch.utils.data.DataLoader( test_audio_dataset, batch_size=opts_test.batchSize, shuffle=False, num_workers=int(opts_test.workers)) ######## for epoch in range(resume_epoch, opt.max_epochs): ################################# # train for one epoch ################################# train(train_loader, model, criterion, optimizer, epoch, opt, test_video_loader, test_audio_loader, opts_test) scheduler.step() ################################## # save checkpoints ################################## # save a checkpoint every opt.epoch_save epochs if ((epoch + 1) % opt.epoch_save) == 0: path_checkpoint = '{0}/{1}_state_epoch{2}.pth'.format( opt.checkpoint_folder, opt.prefix, epoch + 1) utils.save_checkpoint(model.state_dict(), path_checkpoint) plt.figure(1) plt.subplot(1, 2, 1) plt.plot(loss_rec) plt.legend('loss') plt.subplot(1, 2, 2) plt.plot(dis1_rec) plt.plot(dis2_rec) plt.legend(('distance between positives', 'distance between negatives')) plt.savefig("./figures/conv.jpg") plt.show()
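The lambda_lr above is a stepwise decay: the factor shrinks by a factor of opt.lr_decay once every opt.lr_decay_epoch epochs. With hypothetical values lr_decay=0.5 and lr_decay_epoch=10:

lr_decay, lr_decay_epoch = 0.5, 10  # hypothetical values
factors = [lr_decay ** ((epoch + 1) // lr_decay_epoch) for epoch in range(30)]
# epochs 0-8 -> 1.0, epochs 9-18 -> 0.5, epochs 19-28 -> 0.25, epoch 29 -> 0.125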
def run_training( source, target, dataset_root, net_name, da_method, max_iter, stop_iter, test_iter, logdir, run_name, gpu_id, load_workers, config, test_src: bool = False, use_tqdm: bool = True, kill_diverging: bool = False): dev = torch.device(f'cuda:{gpu_id}') if kill_diverging: assert test_src # Get config # Config arrives here (from BOHB or direct cli invocation) as a dictionary like # {'disc.dropout': 0.5, 'net.bottleneck_size_log': 9} # We separate it in something like # {'disc': {'dropout': 0.5'}, 'net': {'bottleneck_size_log': 9}} config = split_dict(config) # Disc args are not meaningful without DA if da_method != 'so': # Default disc args disc_args = { 'dropout': 0.5, 'num_fc_layers': 3, 'hidden_size_log': 10 } # Update with the ones coming from config (if any) disc_args.update(config.get('disc', {})) # Some args might be defined as log2. Replace them (bottleneck_size_log -> bottleneck_size) remove_log_hps(disc_args) # Print disc args print(f"Discriminator config: {disc_args}") # Very similar, but for the backbone net_args = { 'use_bottleneck': da_method != 'so', 'bottleneck_size_log': 9 } net_args.update(config.get('net', {})) remove_log_hps(net_args) print(f"Backbone config: {net_args}") # Now net_args and disc_args are ready to be passed to the network constructors as **kwargs :) bs, lr, wd = config['base']['bs'], config['base']['lr'], config['base']['wd'] # Load datasets and their number o classes dset_src_train, dset_src_test, dset_trg_train, dset_trg_test, num_classes = \ prepare_datasets(source, target, dataset_root) dload_src_train = DataLoader(dset_src_train, batch_size=bs, shuffle=True, num_workers=load_workers, drop_last=True) dload_src_test = DataLoader(dset_src_test, batch_size=bs, shuffle=False, num_workers=load_workers) dload_trg_train = DataLoader(dset_trg_train, batch_size=bs, shuffle=True, num_workers=load_workers, drop_last=True) dload_trg_test = DataLoader(dset_trg_test, batch_size=bs, shuffle=False, num_workers=load_workers) print(f"Source samples: {len(dset_src_train)}") print(f"Target samples: {len(dset_trg_train)}") print(f"Num classes: {num_classes}") # Build network base_network = resnet.ResNetFc( resnet_name=net_name, num_classes=num_classes, plug_position=7, **net_args ).to(dev) params = base_network.get_parameters(lr, wd) # Source only has no secondary branches if da_method != 'so': disc_classes = { # ( -> confusion matrix) 'alda': num_classes, # ( -> binary domain classifier) 'dann': 2 }[da_method] discriminator = resnet.Discriminator(in_feature=base_network.output_size(), num_classes=disc_classes, **disc_args).to(dev) params += discriminator.get_parameters(lr, wd) # Define optimizer optimizer = opt.SGD( params=params, lr=lr, momentum=0.9, weight_decay=wd, nesterov=True ) # Lr policy lr_schedule = LambdaLR(optimizer, lr_lambda=lambda it: (1 + 0.001 * it) ** (-0.75)) # Logger writer = Logger(logdir=logdir, run_name=run_name, use_tb=True, use_tqdm=use_tqdm) # Classification loss ce_loss = nn.CrossEntropyLoss() # Train loop len_train_source = len(dload_src_train) len_train_target = len(dload_trg_train) lambda_val = 0. 
# We store all the metrics here metrics = [] all_pseudolabels = [] with writer.progress(total=stop_iter, desc="Training") as pb: for i in range(stop_iter): if (i + 1) % test_iter == 0: print(f"Iteration: {i + 1} / {stop_iter} (max: {max_iter})") print("Testing...") base_network.train(False) # This dict contains metric-name -> value pairs for the current epoch new_metrics = {} if test_src: test_result, _, src_test_feats = test(dload_src_test, base_network, device=dev) # Print accuracy print("Source accuracy: {:.3f} %".format(test_result['accuracy'] * 100)) # Add the source metrics to the dict (with the source_ prefix) new_metrics.update({f'source_{k}': v for k, v in test_result.items()}) test_result, epoch_pseudolabels, _ = test(dload_trg_test, base_network, device=dev, source_feats=src_test_feats) all_pseudolabels.append(epoch_pseudolabels) print(f"Target accuracy: {test_result['accuracy'] * 100:.3f} %") writer.add_scalar('train/base_lr', lr_schedule.get_last_lr()[0], i) writer.add_scalar('train/lambda', lambda_val, i) new_metrics.update({f'target_{k}': v for k, v in test_result.items()}) # Add all the new metrics to tensorboard logs add_scalars(writer, new_metrics, global_step=i, prefix='test/') # Add a column with iteration number new_metrics.update({'iter': i}) # Concatenate to older epoch metrics metrics.append(new_metrics) # Kill this training if the source loss is rising and exceeds the threshold # (compare against the previous evaluation; new_metrics itself is already metrics[-1]) if kill_diverging and new_metrics['source_class_loss'] > SOURCE_LOSS_THRESHOLD: if len(metrics) > 1 and new_metrics['source_class_loss'] > metrics[-2]['source_class_loss']: print(f"Increasing source_class_loss exceeds the maximum allowed source loss ({new_metrics['source_class_loss']} > {SOURCE_LOSS_THRESHOLD})") break # Train one iteration base_network.train(True) if da_method != 'so': discriminator.train(True) optimizer.zero_grad() # Reset data loops if required if i % len_train_source == 0: iter_source = iter(dload_src_train) if i % len_train_target == 0: iter_target = iter(dload_trg_train) # Load source inputs_source, labels_source = next(iter_source) inputs_source, labels_source = map_to_device(dev, (inputs_source, labels_source)) # Compute source features and classification output outputs_source, features_source = base_network(inputs_source) # Classification loss classifier_loss = ce_loss(outputs_source, labels_source) # Actual DA part if da_method != 'so': # Load target samples without target labels inputs_target, _ = next(iter_target) inputs_target = inputs_target.to(dev) # Compute target features and classification output outputs_target, features_target = base_network(inputs_target) # Source and target features features = torch.cat((features_source, features_target), dim=0) # Source and target classification outputs (-> softmax) outputs = torch.cat((outputs_source, outputs_target), dim=0) softmax_out = nn.Softmax(dim=1)(outputs) # CORE if da_method == 'dann': p = float(i / max_iter) lambda_val = 2. / (1 + np.exp(-10 * p)) - 1 ad_out = discriminator(features, lambda_val) adv_loss = loss.DANN_loss(ad_out) transfer_loss = adv_loss if (i + 1) % test_iter == 0: print("Transfer loss: {:.3f}".format(transfer_loss.item())) elif da_method == 'alda': p = float(i / max_iter) lambda_val = 2. / (1 + np.exp(-10 * p)) - 1
ad_out = discriminator(features, lambda_val) adv_loss, reg_loss, correct_loss = loss.ALDA_loss(ad_out, labels_source, softmax_out, threshold=0.9) transfer_loss = adv_loss + lambda_val * correct_loss if (i + 1) % test_iter == 0: print("Transfer loss: {:.3f}, reg loss: {:.3f}".format(transfer_loss.item(), reg_loss.item())) # Backpropagate reg_loss only through the discriminator with base_network.freeze(): reg_loss.backward(retain_graph=True) # END CORE else: transfer_loss = 0 total_loss = classifier_loss + config['base']['weight_da'] * transfer_loss total_loss.backward() optimizer.step() lr_schedule.step() if (i + 1) % test_iter == 0 and da_method != 'so': writer.add_scalar('train/transfer_loss', transfer_loss.item(), i) pb.update(1) # Convert list of dicts to dataframe containing metrics metrics = pd.DataFrame(metrics) # Compute global-pseudolabel accuracy all_pseudolabels = np.array(all_pseudolabels) global_pseudolabels = compute_time_consistent_pseudolabels(all_pseudolabels, num_classes) pseudolabel_acc = np.equal(all_pseudolabels, global_pseudolabels).sum(axis=1) / global_pseudolabels.shape[0] # Add it to the metrics dataframe metrics['target_pseudolabels'] = pseudolabel_acc # Save the metrics with open(os.path.join(logdir, run_name, "metrics.pkl"), "wb") as fp: pickle.dump(metrics, fp) # Log global pseudolabel accuracy to tensorboard for i in range(len(all_pseudolabels)): writer.add_scalar('test/target_pseudolabels', float(pseudolabel_acc[i]), i * test_iter) return metrics
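The backbone schedule used above, lr_lambda=lambda it: (1 + 0.001 * it) ** (-0.75), is the standard DANN-style inverse decay. Sampling a few values shows how gently it anneals:

decay = lambda it: (1 + 0.001 * it) ** (-0.75)
print([round(decay(it), 3) for it in (0, 1000, 10000, 100000)])
# [1.0, 0.595, 0.166, 0.031]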
def train(self) -> None: r"""Main method for training PPO. Returns: None """ profiling_wrapper.configure( capture_start_step=self.config.PROFILING.CAPTURE_START_STEP, num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE, ) self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.obs_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, ) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) batch = apply_obs_transforms_batch(batch, self.obs_transforms) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats: DefaultDict[str, deque] = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES ), # type: ignore ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(self.config.NUM_UPDATES): profiling_wrapper.on_start_step() profiling_wrapper.range_push("train update") if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() # type: ignore if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) profiling_wrapper.range_push("rollouts loop") for _step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps profiling_wrapper.range_pop() # rollouts loop ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) losses = [value_loss, action_loss] writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / 
(time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint(f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 profiling_wrapper.range_pop() # train update self.envs.close()
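linear_decay is imported from the habitat-baselines utilities; a minimal sketch consistent with how the PPO loop uses it (the factor falls linearly from 1 to 0 over NUM_UPDATES; treat this body as a reconstruction, not the library source):

def linear_decay(epoch, total_num_updates):
    # multiplicative factor: 1.0 at the first update, 0.0 at the last
    return 1 - (epoch / float(total_num_updates))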
def train( run_name: str, # Data train_filepath: str = CSNJS_TRAIN_FILEPATH, eval_filepath: str = CSNJS_VALID_FILEPATH, spm_filepath: str = SPM_UNIGRAM_FILEPATH, program_mode="identity", eval_program_mode="identity", label_mode="identifier", num_workers=1, limit_dataset_size=-1, # Model model_type="transformer", n_decoder_layers=4, d_model: int = 512, resume_path: str = "", resume_encoder_name: str = "encoder_q", # encoder_q, encoder_k, encoder resume_project: bool = False, # Optimization train_decoder_only: bool = False, num_epochs: int = 50, save_every: int = 2, batch_size: int = 256, lr: float = 8e-4, adam_beta1: float = 0.9, adam_beta2: float = 0.98, use_lr_warmup: bool = True, loss_type="nll_token", # nll_token or nll_sequence # Loss subword_regularization_alpha: float = 0, # Computational use_cuda: bool = True, auto_test: bool = True, seed: int = 0, ): """Train model""" torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) run_dir = RUN_DIR / run_name run_dir.mkdir(exist_ok=True, parents=True) logger.add(str((run_dir / "train.log").resolve())) logger.info(f"Saving logs, model checkpoints to {run_dir}") config = locals() logger.info(f"Config: {config}") wandb.init(name=run_name, config=config, job_type="training", project="identifier-prediction", entity="ml4code") if use_cuda: assert torch.cuda.is_available( ), "CUDA not available. Check env configuration, or pass --use_cuda False" train_augmentations = [ { "fn": "sample_lines", "line_length_pct": 0.5 }, { "fn": "insert_var_declaration", "prob": 0.5 }, { "fn": "rename_variable", "prob": 0.5 }, ] sp = spm.SentencePieceProcessor() sp.Load(spm_filepath) pad_id = sp.PieceToId("[PAD]") # Create training dataset and dataloader logger.info(f"Training data path {train_filepath}") train_dataset = get_csnjs_dataset(train_filepath, label_mode=label_mode, limit_size=limit_dataset_size) logger.info(f"Training dataset size: {len(train_dataset)}") train_loader = javascript_dataloader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, augmentations=train_augmentations, sp=sp, program_mode=program_mode, subword_regularization_alpha=subword_regularization_alpha, ) # Create eval dataset and dataloader logger.info(f"Eval data path {eval_filepath}") eval_dataset = get_csnjs_dataset(eval_filepath, label_mode=label_mode, limit_size=limit_dataset_size) logger.info(f"Eval dataset size: {len(eval_dataset)}") eval_loader = javascript_dataloader( eval_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, augmentations=[], sp=sp, program_mode=eval_program_mode, subword_regularization_alpha=subword_regularization_alpha, ) # Create model pad_id = sp.PieceToId("[PAD]") if model_type == "transformer": model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model) logger.info( f"Created TransformerModel with {count_parameters(model)} params") elif model_type == "lstm": model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model) logger.info( f"Created Seq2SeqLSTM with {count_parameters(model)} params") # Load checkpoint if resume_path: logger.info( f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}" ) checkpoint = torch.load(resume_path) pretrained_state_dict = checkpoint["model_state_dict"] encoder_state_dict = {} assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"] for key, value in pretrained_state_dict.items(): if key.startswith(resume_encoder_name + ".") and "project_layer" not 
in key: remapped_key = key[len(resume_encoder_name + "."):] logger.debug( f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}" ) encoder_state_dict[remapped_key] = value if key.startswith( resume_encoder_name + ".") and "project_layer.0." in key and resume_project: remapped_key = key[len(resume_encoder_name + "."):] logger.debug( f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}" ) encoder_state_dict[remapped_key] = value model.encoder.load_state_dict(encoder_state_dict, strict=False) logger.info(f"Loaded state dict from {resume_path}") logger.info(f"Loaded keys: {encoder_state_dict.keys()}") # Set up optimizer model = nn.DataParallel(model) model = model.cuda() if use_cuda else model wandb.watch(model, log="all") params = model.module.decoder.parameters( ) if train_decoder_only else model.parameters() optimizer = torch.optim.Adam(params, lr=lr, betas=(adam_beta1, adam_beta2), eps=1e-9) if use_lr_warmup: scheduler = get_linear_schedule_with_warmup( optimizer, 5000, len(train_loader) * num_epochs) else: scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0) global_step = 0 min_eval_loss = float("inf") for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False): logger.info(f"Starting epoch {epoch}\n") if train_decoder_only: model.module.encoder.eval() model.module.decoder.train() else: model.train() pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}") for X, Y, X_lengths, Y_lengths in pbar: if use_cuda: X = X.cuda() Y = Y.cuda() X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda() optimizer.zero_grad() # NOTE: X and Y are [B, max_seq_len] tensors (batch first) logits = model(X, Y[:, :-1], X_lengths, Y_lengths) if loss_type == "nll_sequence": loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction='sum') loss = loss / X.size( 0 ) # Average over num sequences, not target sequence lengths # Thus, minimize bits per sequence. 
elif loss_type == "nll_token": loss = F.cross_entropy( logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, ) loss.backward() optimizer.step() scheduler.step() # Log loss global_step += 1 wandb.log( { "epoch": epoch, f"label-{label_mode}/train_loss": loss.item(), "lr": scheduler.get_last_lr()[0] }, step=global_step) pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}") # Evaluate logger.info( f"Evaluating model after epoch {epoch} ({global_step} steps)...") max_decode_len = 20 if label_mode == "identifier" else 200 eval_loss = _evaluate(model, eval_loader, sp, use_cuda=use_cuda, max_decode_len=max_decode_len, loss_type=loss_type) logger.info( f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}" ) wandb.log({ "epoch": epoch, f"label-{label_mode}/eval_loss": eval_loss }, step=global_step) # Save checkpoint if save_every and epoch % save_every == 0 or eval_loss < min_eval_loss: checkpoint = { "model_state_dict": model.module.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch, "global_step": global_step, "config": config, "eval_loss": eval_loss, } if eval_loss < min_eval_loss: logger.info( f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}" ) min_eval_loss = eval_loss model_file = run_dir / "ckpt_best.pth" else: model_file = run_dir / f"ckpt_ep{epoch:04d}.pth" logger.info(f"Saving checkpoint to {model_file}...") torch.save(checkpoint, str(model_file.resolve())) wandb.save(str(model_file.resolve())) logger.info("Done.") if auto_test: best_ckpt = run_dir / "ckpt_best.pth" test( str(best_ckpt.resolve()), CSNJS_TEST_FILEPATH, spm_filepath, program_mode, label_mode, num_workers, -1, n_decoder_layers=n_decoder_layers, )
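get_linear_schedule_with_warmup (from transformers) is itself implemented as a LambdaLR, so the two branches above differ only in the lambda. A hand-rolled equivalent of the warmup branch, for reference:

from torch.optim.lr_scheduler import LambdaLR

def linear_warmup_then_decay(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        # ramp 0 -> 1 over the warmup, then decay linearly to 0
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - step) /
                   float(max(1, num_training_steps - num_warmup_steps)))
    return LambdaLR(optimizer, lr_lambda)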
def training( # noqa: C901 self, tproblem, hyperparams, num_epochs, print_train_iter, train_log_interval, tb_log, tb_log_dir, **training_params, ): """Training loop for this runner. Args: tproblem (deepobs.pytorch.testproblems.testproblem): The testproblem instance to train on. hyperparams (dict): The optimizer hyperparameters to use for the training. num_epochs (int): The number of training epochs. print_train_iter (bool): Whether to print the training progress at every train_log_interval train_log_interval (int): Mini-batch interval for logging. tb_log (bool): Whether to use tensorboard logging or not tb_log_dir (str): The path where to save tensorboard events. **training_params (dict): Kwargs for additional training parameters that will be used by the cockpit. Returns: dict: Output of the training loop """ opt = self._optimizer_class(tproblem.net.parameters(), **hyperparams) # Using a LR Scheduler lr_sched = training_params["lr_schedule"](num_epochs) scheduler = LambdaLR(opt, lr_lambda=lr_sched) # COCKPIT: Initialize it # logpath = self._get_cockpit_logpath() # Integrate BackPACK extend_with_access_unreduced_loss(tproblem) trainable_params = [ p for p in tproblem.net.parameters() if p.requires_grad ] cockpit = Cockpit(trainable_params, quantities=self._quantities) plotter = CockpitPlotter(secondary_screen=self._secondary_screen) if self._plot_schedule is not None: plot_schedule = self._plot_schedule else: warnings.warn( "You are using plot_interval, which will be deprecated. " + "Use plot_schedule instead") plot_schedule = schedules.linear(training_params["plot_interval"]) # Lists to log train/test loss and accuracy. train_losses = [] valid_losses = [] test_losses = [] train_accuracies = [] valid_accuracies = [] test_accuracies = [] minibatch_train_losses = [] if tb_log: try: from torch.utils.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=tb_log_dir) except ImportError as e: warnings.warn( "Not possible to use tensorboard for pytorch. Reason: " + e.msg, RuntimeWarning, ) tb_log = False global_step = 0 for epoch_count in range(num_epochs + 1): # Evaluate at beginning of epoch. 
if self._should_eval(): self.evaluate_all( epoch_count, num_epochs, tproblem, train_losses, valid_losses, test_losses, train_accuracies, valid_accuracies, test_accuracies, ) # COCKPIT: Log already computed quantities # cockpit.log( global_step, epoch_count, train_losses[-1], valid_losses[-1], test_losses[-1], train_accuracies[-1], valid_accuracies[-1], test_accuracies[-1], opt.param_groups[0]["lr"], ) # Break from train loop after the last round of evaluation if epoch_count == num_epochs: break # Training # # set to training mode tproblem.train_init_op() batch_count = 0 while True: try: opt.zero_grad() batch_loss, _ = tproblem.get_batch_loss_and_accuracy( reduction="mean") info = { "batch_size": self._extract_batch_size(batch_loss), "individual_losses": self._extract_individual_losses(batch_loss, ), "loss": batch_loss, "optimizer": opt, } # COCKPIT: Use necessary BackPACK extensions and track # with cockpit(global_step, info=info): batch_loss.backward( create_graph=cockpit.create_graph(global_step)) if plot_schedule(global_step): plotter.plot( cockpit, savedir=logpath, show_plot=training_params["show_plots"], save_plot=training_params["save_plots"], savename_append="__epoch__" + str(epoch_count).zfill(len(str(num_epochs))) + "__global_step__" + str(global_step).zfill(6), ) opt.step() if batch_count % train_log_interval == 0: minibatch_train_losses.append(batch_loss.item()) if print_train_iter: print("Epoch {0:d}, step {1:d}: loss {2:g}".format( epoch_count, batch_count, batch_loss)) if tb_log: summary_writer.add_scalar("loss", batch_loss.item(), global_step) batch_count += 1 global_step += 1 self._maybe_stop_iteration(global_step, batch_count) except StopIteration: break # Next step in LR Schedule scheduler.step() # COCKPIT: Write to file and optionally plot after last epoch # cockpit.write(logpath) if self._enable_plotting: plotter.plot( cockpit, savedir=logpath, show_plot=training_params["show_plots"], save_plot=training_params["save_final_plot"], ) if training_params["save_animation"]: plotter.build_animation(logpath) if tb_log: summary_writer.close() # Put results into output dictionary. output = { "train_losses": train_losses, "valid_losses": valid_losses, "test_losses": test_losses, "minibatch_train_losses": minibatch_train_losses, "train_accuracies": train_accuracies, "valid_accuracies": valid_accuracies, "test_accuracies": test_accuracies, } return output
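training_params["lr_schedule"] above is expected to map num_epochs to a callable usable as LambdaLR's lr_lambda. A trivial, purely hypothetical factory that satisfies that contract:

def constant_schedule(num_epochs):
    # keep the optimizer's base lr for all epochs
    return lambda epoch: 1.0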
def main(): global opt loss_rec = np.zeros((opt.folds, 100)) acc_rec = np.zeros((opt.folds, 100)) #loss_rec = np.load('acc_train.npy') #acc_rec = np.load('acc.npy') for iteration in range(opt.folds): train_dataset = mnist_Dataset(num_of_cross=iteration) print('number of train samples is: {0}'.format(len(train_dataset))) print('finished loading data') if opt.manualSeed is None: opt.manualSeed = random.randint(1, 10000) if torch.cuda.is_available() and not opt.cuda: print( "WARNING: You have a CUDA device, so you should probably run with \"cuda: True\"" ) torch.manual_seed(opt.manualSeed) else: if int(opt.ngpu) == 1: print('so we use 1 gpu to training') print('setting gpu on gpuid {0}'.format(opt.gpu_id)) if opt.cuda: os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id torch.cuda.manual_seed(opt.manualSeed) cudnn.benchmark = True print('Random Seed: {0}'.format(opt.manualSeed)) # train data loader train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batchSize, shuffle=True, num_workers=int( opt.workers)) # create model model = mnist_model.cat_and_dog_resnet() if opt.init_model != '': print('loading pretrained model from {0}'.format(opt.init_model)) model.load_state_dict(torch.load(opt.init_model)) # Contrastive Loss #criterion = mnist_model.StableBCELoss() criterion = nn.CrossEntropyLoss() if opt.cuda: print('shift model and criterion to GPU .. ') model = model.cuda() criterion = criterion.cuda() # optimizer # optimizer = optim.SGD(model.parameters(), lr=opt.lr, # momentum=opt.momentum, # weight_decay=opt.weight_decay) optimizer = optim.Adam(model.parameters(), lr=opt.lr) # optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, momentum=opt.momentum) # optimizer = optim.Adadelta(params=model.parameters(), lr=opt.lr) # adjust learning rate every lr_decay_epoch lambda_lr = lambda epoch: opt.lr_decay**( (epoch + 1) // opt.lr_decay_epoch) # poly policy scheduler = LR_Policy(optimizer, lambda_lr) resume_epoch = 0 acc = test(model, opt, iteration) acc_rec[iteration][0] = acc acc = test(model, opt, iteration, Training=True) loss_rec[iteration][0] = acc for epoch in range(resume_epoch, opt.max_epochs): ################################# # train for one epoch ################################# #accuracy = test(model, opt, epoch) train(train_loader, model, criterion, optimizer, iteration, opt, epoch) scheduler.step() ################################## # save checkpoints ################################## # save model every 10 epochs accuracy = test(model, opt, iteration) acc_rec[iteration][epoch + 1] = accuracy np.save('acc.npy', acc_rec) accuracy = test(model, opt, iteration, Training=True) loss_rec[iteration][epoch + 1] = accuracy np.save('acc_train.npy', loss_rec) if ((epoch + 1) % opt.epoch_save) == 0: path_checkpoint = '{0}/{1}_{3}_epoch{2}.pth'.format( opt.checkpoint_folder, opt.prefix, epoch + 1, iteration) utils.save_checkpoint(model.state_dict(), path_checkpoint)
def learn(device, env, nenv, seed, number_timesteps, network, optimizer, save_path, save_interval, ob_scale, lr, gamma, grad_norm, timesteps_per_batch, ent_coef, vf_coef, **kwargs): """ Paper: Mnih V, Badia A P, Mirza M, et al. Asynchronous methods for deep reinforcement learning[C]// International Conference on Machine Learning. 2016: 1928-1937. Parameters: ---------- grad_norm (float | None): max norm for gradient clipping timesteps_per_batch (int): number of steps per update ent_coef (float): policy entropy coefficient in the objective vf_coef (float): value function loss coefficient in the objective """ name = '{}_{}'.format(os.path.split(__file__)[-1][:-3], seed) logger = get_logger(name) policy = build_policy(env, network, estimate_value=True).to(device) optimizer = get_optimizer(optimizer, policy.parameters(), lr) number_timesteps = number_timesteps // nenv generator = _generate(device, env, policy, ob_scale, number_timesteps, gamma, timesteps_per_batch) max_iter = number_timesteps // timesteps_per_batch scheduler = LambdaLR(optimizer, lambda i_iter: 1 - i_iter / max_iter) total_timesteps = 0 infos = { k: deque(maxlen=100) for k in ['eplenmean', 'eprewmean', 'pgloss', 'v', 'entropy'] } start_ts = time.time() for n_iter in range(1, max_iter + 1): batch = next(generator) b_o, b_a, b_r, b_v_old, info = batch for d in info: infos['eplenmean'].append(d['l']) infos['eprewmean'].append(d['r']) total_timesteps += b_o[0].size(0) # calculate advantage b_logits, b_v = policy(b_o) b_v = b_v[:, 0] dist = torch.distributions.Categorical(logits=b_logits) entropy = dist.entropy().mean() b_logp = dist.log_prob(b_a) adv = b_r - b_v_old # update policy vloss = (b_v - b_r).pow(2).mean() pgloss = -(adv * b_logp).mean() loss = pgloss + vf_coef * vloss - ent_coef * entropy optimizer.zero_grad() loss.backward() if grad_norm is not None: nn.utils.clip_grad_norm_(policy.parameters(), grad_norm) optimizer.step() scheduler.step() # decay the lr after the optimizer update, not before it # record logs infos['pgloss'].append(pgloss.item()) infos['v'].append(vloss.item()) infos['entropy'].append(entropy.item()) logger.info('{} Iter {} {}'.format('=' * 10, n_iter, '=' * 10)) fps = int(total_timesteps / (time.time() - start_ts)) logger.info('Total timesteps {} FPS {}'.format(total_timesteps, fps)) for k, v in infos.items(): v = (sum(v) / len(v)) if v else float('nan') logger.info('{}: {:.6f}'.format(k, v)) if save_interval and n_iter % save_interval == 0: torch.save([policy.state_dict(), optimizer.state_dict()], os.path.join(save_path, '{}.{}'.format(name, n_iter)))
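The A2C schedule lambda i_iter: 1 - i_iter / max_iter anneals the lr linearly and reaches exactly zero once the lambda is evaluated at max_iter, which is worth keeping in mind when choosing max_iter:

max_iter = 100  # hypothetical
lam = lambda i_iter: 1 - i_iter / max_iter
print(lam(0), lam(50), lam(100))  # 1.0 0.5 0.0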
def lr_range_test(self, data_loader, end_lr, num_iter=100, step_mode='exp', alpha=0.05, ax=None): # Since the test updates both model and optimizer we need to store # their initial states to restore them in the end previous_states = { 'model': deepcopy(self.model.state_dict()), 'optimizer': deepcopy(self.optimizer.state_dict()) } # Retrieves the learning rate set in the optimizer start_lr = self.optimizer.state_dict()['param_groups'][0]['lr'] # Builds a custom function and corresponding scheduler lr_fn = make_lr_fn(start_lr, end_lr, num_iter) scheduler = LambdaLR(self.optimizer, lr_lambda=lr_fn) # Variables for tracking results and iterations tracking = {'loss': [], 'lr': []} iteration = 0 # If there are more iterations than mini-batches in the data loader, # it will have to loop over it more than once while (iteration < num_iter): # That's the typical mini-batch inner loop for x_batch, y_batch in data_loader: x_batch = x_batch.to(self.device) y_batch = y_batch.to(self.device) # Step 1 yhat = self.model(x_batch) # Step 2 loss = self.loss_fn(yhat, y_batch) # Step 3 loss.backward() # Here we keep track of the losses (smoothed) # and the learning rates tracking['lr'].append(scheduler.get_last_lr()[0]) if iteration == 0: tracking['loss'].append(loss.item()) else: prev_loss = tracking['loss'][-1] smoothed_loss = alpha * loss.item() + (1 - alpha) * prev_loss tracking['loss'].append(smoothed_loss) iteration += 1 # Number of iterations reached if iteration == num_iter: break # Step 4 self.optimizer.step() scheduler.step() self.optimizer.zero_grad() # Restores the original states self.optimizer.load_state_dict(previous_states['optimizer']) self.model.load_state_dict(previous_states['model']) if ax is None: fig, ax = plt.subplots(1, 1, figsize=(6, 4)) else: fig = ax.get_figure() ax.plot(tracking['lr'], tracking['loss']) if step_mode == 'exp': ax.set_xscale('log') ax.set_xlabel('Learning Rate') ax.set_ylabel('Loss') fig.tight_layout() return tracking, fig
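make_lr_fn builds the multiplicative factor that sweeps the lr from start_lr to end_lr over num_iter steps. A sketch of a compatible implementation (the exact body is a reconstruction under that assumption):

import numpy as np

def make_lr_fn(start_lr, end_lr, num_iter, step_mode='exp'):
    if step_mode == 'linear':
        factor = (end_lr / start_lr - 1) / num_iter
        return lambda iteration: 1 + iteration * factor
    # exponential sweep: after num_iter steps the factor equals end_lr / start_lr
    factor = (np.log(end_lr) - np.log(start_lr)) / num_iter
    return lambda iteration: np.exp(factor) ** iteration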
optimizer = SGD(model.parameters(), lr=0.03, momentum=0.9, weight_decay=5e-4)
lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda i: 0.5 * (math.cos(i * math.pi / epochs) + 1))
#c = len(memory_data.classes)
c = 2
results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []}
save_name_pre = '{}_{}_{}_{}'.format(feature_dim, k, batch_size, epochs)
if not os.path.exists('results'):
    os.mkdir('results')
best_acc = 0.0

# training loop
for epoch in range(1, epochs + 1):
    train_loss = train(model, train_loader, optimizer)
    results['train_loss'].append(train_loss)
    lr_scheduler.step()
    #test_acc_1, test_acc_5 = test(model, memory_loader, test_loader)
    #results['test_acc@1'].append(test_acc_1)
    #results['test_acc@5'].append(test_acc_5)
    # save statistics
    #data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
    #data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
    #if test_acc_1 > best_acc:
    #best_acc = test_acc_1
    torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre))
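The cosine lambda above starts at 1 and decays smoothly to 0 by the final epoch:

import math

epochs = 200  # hypothetical
lam = lambda i: 0.5 * (math.cos(i * math.pi / epochs) + 1)
print(lam(0), round(lam(epochs // 2), 3), lam(epochs))  # 1.0 0.5 0.0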
def train(model, tokenizer, train_data, valid_data, args): model.train() train_dataset = TextDataset(train_data) train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=args.train_batch_size, num_workers=args.num_workers, collate_fn=lambda x: collate_fn_bert( x, tokenizer, args.max_seq_length)) valid_dataset = TextDataset(valid_data) valid_dataloader = DataLoader(valid_dataset, sampler=SequentialSampler(valid_dataset), batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=lambda x: collate_fn_bert( x, tokenizer, args.max_seq_length)) valid_noisy = [x['noisy'] for x in valid_data] valid_clean = [x['clean'] for x in valid_data] epochs = (args.max_steps - 1) // len(train_dataloader) + 1 # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, # betas=eval(args.adam_betas), eps=args.eps, # weight_decay=args.weight_decay) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else ( x / args.num_warmup_steps)**-0.5 scheduler = LambdaLR(optimizer, lr_lambda) step = 0 best_val_gleu = -float("inf") meter = Meter() for epoch in range(1, epochs + 1): for batch in train_dataloader: step += 1 batch = tuple(t.to(args.device) for t in batch) noise_input_ids, clean_input_ids, noise_mask, clean_mask = batch #print("noise shape: {}, clean shape: {}".format(noise_input_ids.shape, clean_input_ids.shape)) outputs = model(noise_input_ids, labels=clean_input_ids, attention_mask=noise_mask) loss = outputs[0] predict_score = outputs[1] bsz = clean_input_ids.size(0) items = [loss.data.item(), bsz, clean_mask.sum().item()] #print("items: ", items) meter.add(*items) loss.backward() if args.max_grad_norm > 0: nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() model.zero_grad() scheduler.step() if step % args.log_interval == 0: lr = scheduler.get_last_lr()[0] loss_sent, loss_token = meter.average() logger.info( f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}') nsml.report(step=step, scope=locals(), summary=True, train__lr=lr, train__loss_sent=loss_sent, train__token_ppl=math.exp(loss_token)) meter.init() if step % args.eval_interval == 0: start_eval = time.time() (val_loss, val_loss_token), valid_str = evaluate_kcBert( model, valid_dataloader, args) prediction = correct_kcBert(model, tokenizer, valid_noisy, args, length_limit=0.1) val_em = em(prediction, valid_clean) cnt = 0 for noisy, pred, clean in zip(valid_noisy, prediction, valid_clean): print(f'[{noisy}], [{pred}], [{clean}]') # print only the first 20 samples cnt += 1 if cnt == 20: break # print("len of prediction: {}, len of valid_clean: {}", len(prediction), len(valid_clean)) val_gleu = gleu(prediction, valid_clean) logger.info('-' * 89) logger.info( f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}' ) logger.info('-' * 89) nsml.report(step=step, scope=locals(), summary=True, valid__loss_sent=val_loss, valid__token_ppl=math.exp(val_loss_token), valid__em=val_em, valid__gleu=val_gleu) if val_gleu > best_val_gleu: best_val_gleu = val_gleu nsml.save("best") meter.start += time.time() - start_eval if step >= args.max_steps: break if step >= args.max_steps: break
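The lr_lambda above ramps linearly up to the base lr over num_warmup_steps, then decays as the inverse square root of the warmup-relative progress; the two branches meet at exactly 1.0 at the end of warmup:

num_warmup_steps = 1000  # hypothetical
lr_lambda = lambda x: x / num_warmup_steps if x <= num_warmup_steps else (x / num_warmup_steps) ** -0.5
print(lr_lambda(500), lr_lambda(1000), lr_lambda(4000))  # 0.5 1.0 0.5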
class Model: def __init__(self, device, num_steps): # in and out channels # for the generator: a, b = 1, 3 def weights_init(m): if isinstance(m, nn.Conv2d): init.normal_(m.weight, std=0.02) if m.bias is not None: init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): init.ones_(m.weight) init.zeros_(m.bias) G = Generator(a, b).train() self.G = G.apply(weights_init).to(device) # it turns out that this is important init.normal_(self.G.end[0].weight, std=1e-4) def lambda_rule(i): decay = num_steps // 4 m = 1.0 if i < decay else 1.0 - (i - decay) / (num_steps - decay) return max(m, 1e-3) self.optimizer = optim.Adam(self.G.parameters(), lr=2e-4, betas=(0.9, 0.999)) self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda_rule) self.cp_loss = CPLoss() self.gp_loss = GPLoss() if USE_FLOAT16: self.G, self.optimizer = amp.initialize(self.G, self.optimizer, opt_level='O2') # a copy for exponential moving average self.G_ema = copy.deepcopy(self.G) def train_step(self, A, B): """ The input tensors represent images with pixel values in [0, 1] range. Arguments: A: a float tensor with shape [n, a, h, w]. B: a float tensor with shape [n, b, h, w]. Returns: a dict with float numbers. """ self.optimizer.zero_grad() B_restored = self.G(A) # it has shape [n, b, h, w] cp_loss = self.cp_loss(B_restored, B) gp_loss = self.gp_loss(B_restored, B) reconstruction_loss = cp_loss + gp_loss if not USE_FLOAT16: reconstruction_loss.backward() else: with amp.scale_loss(reconstruction_loss, self.optimizer) as loss_scaled: loss_scaled.backward() self.optimizer.step() self.scheduler.step() # running average of weights accumulate(self.G_ema, self.G) loss_dict = { 'total_loss': reconstruction_loss.item(), 'cp_loss': cp_loss.item(), 'gp_loss': gp_loss.item() } return loss_dict def save_model(self, model_path): torch.save(self.G.state_dict(), model_path + '_generator.pth') torch.save(self.G_ema.state_dict(), model_path + '_generator_ema.pth')
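accumulate is external to this snippet; a minimal sketch of the usual EMA helper it appears to refer to (the decay value and body are assumptions):

def accumulate(model_ema, model, decay=0.999):
    # exponential moving average: ema <- decay * ema + (1 - decay) * current
    ema_params = dict(model_ema.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)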
class TriviaQA(pl.LightningModule): def __init__(self, args): super(TriviaQA, self).__init__() self.args = args self.hparams = args self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base') self.tokenizer.model_max_length = self.args.max_seq_len self.model = self.load_model() self.num_labels = 2 self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels) self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None def load_model(self): model = Longformer.from_pretrained(self.args.model_path) for layer in model.encoder.layer: layer.attention.self.attention_mode = self.args.attention_mode self.args.attention_window = layer.attention.self.attention_window print("Loaded model with config:") print(model.config) for p in model.parameters(): p.requires_grad_(True) model.train() return model def forward(self, input_ids, attention_mask, segment_ids, start_positions, end_positions): question_end_index = self._get_question_end_index(input_ids) # Each batch is one document, and each row of the batch is a chunk of the document. # Make sure all rows have the same question length. assert (question_end_index[0].float() == question_end_index.float().mean()).item() # local attention everywhere attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # global attention for the question tokens attention_mask[:, :question_end_index.item()] = 2 # the sliding_chunks implementation of self-attention requires that seqlen is a multiple of the window size input_ids, attention_mask = pad_to_window_size( input_ids, attention_mask, self.args.attention_window, self.tokenizer.pad_token_id) sequence_output = self.model(input_ids, attention_mask=attention_mask)[0] # The pretrained TriviaQA model wasn't trained with padding, so remove padding tokens # before computing loss and decoding.
padding_len = input_ids[0].eq(self.tokenizer.pad_token_id).sum() if padding_len > 0: sequence_output = sequence_output[:, :-padding_len] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) outputs = ( start_logits, end_logits, ) if start_positions is not None and end_positions is not None: # If we are on multi-GPU, remove the extra dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) if not self.args.regular_softmax_loss: # loss function suggested in section 2.2 here https://arxiv.org/pdf/1710.10723.pdf # NOTE: this returns the sum of losses, not the mean, so the loss isn't normalized across different batch sizes # but batch size is always 1, so this is not a problem start_loss = self.or_softmax_cross_entropy_loss_one_doc( start_logits, start_positions, ignore_index=-1) end_loss = self.or_softmax_cross_entropy_loss_one_doc( end_logits, end_positions, ignore_index=-1) else: loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1) start_positions = start_positions[:, 0:1] end_positions = end_positions[:, 0:1] start_loss = loss_fct(start_logits, start_positions[:, 0]) end_loss = loss_fct(end_logits, end_positions[:, 0]) total_loss = (start_loss + end_loss) / 2 outputs = (total_loss, ) + outputs return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) def or_softmax_cross_entropy_loss_one_doc(self, logits, target, ignore_index=-1, dim=-1): """loss function suggested in section 2.2 here https://arxiv.org/pdf/1710.10723.pdf""" assert logits.ndim == 2 assert target.ndim == 2 assert logits.size(0) == target.size(0) # with regular CrossEntropyLoss, the numerator is only one of the logits specified by the target # here, the numerator is the sum of a few potential targets, some of which are correct answers # compute a target mask target_mask = target == ignore_index # replaces ignore_index with 0, so `gather` will select the logit at index 0 for the masked targets masked_target = target * (1 - target_mask.long()) # gather logits gathered_logits = logits.gather(dim=dim, index=masked_target) # Apply the mask to gathered_logits. Use a mask of -inf because exp(-inf) = 0 gathered_logits[target_mask] = float('-inf') # each batch is one example gathered_logits = gathered_logits.view(1, -1) logits = logits.view(1, -1) # numerator = log(sum(exp(gathered logits))) log_score = torch.logsumexp(gathered_logits, dim=dim, keepdim=False) # denominator = log(sum(exp(logits))) log_norm = torch.logsumexp(logits, dim=dim, keepdim=False) # compute the loss loss = -(log_score - log_norm) # some of the examples might have a loss of `inf` when `target` is all `ignore_index`. # remove those from the loss before computing the sum. Use sum instead of mean because it is easier to compute. return loss[~torch.isinf(loss)].sum()
def training_step(self, batch, batch_nb): input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends) loss = output[0] lr = loss.new_zeros( 1) + self.trainer.optimizers[0].param_groups[0]['lr'] tensorboard_logs = { 'train_loss': loss, 'lr': lr, 'input_size': input_ids.numel(), 'mem': torch.cuda.memory_allocated(input_ids.device) / 1024**3 } return {'loss': loss, 'log': tensorboard_logs} def validation_step(self, batch, batch_nb): input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends) loss, start_logits, end_logits = output[:3] answers = self.decode(input_ids, start_logits, end_logits) # each batch is one document answers = sorted(answers, key=lambda x: x['score'], reverse=True)[0:1] qids = [qids] aliases = [aliases] f1_scores = [ evaluation_utils.metric_max_over_ground_truths( evaluation_utils.f1_score, answer['text'], aliase_list) for answer, aliase_list in zip(answers, aliases) ] # TODO: if slow, skip em_scores, and use (f1_score == 1.0) instead em_scores = [ evaluation_utils.metric_max_over_ground_truths( evaluation_utils.exact_match_score, answer['text'], aliase_list) for answer, aliase_list in zip(answers, aliases) ] answer_scores = [answer['score'] for answer in answers] # start_logit + end_logit assert len(answer_scores) == len(f1_scores) == len(em_scores) == len( qids) == len(aliases) == 1 # TODO: delete this metric pred_subword_starts = start_logits.argmax(dim=1) pred_subword_ends = end_logits.argmax(dim=1) exact_match = (subword_ends[:, 0].squeeze(dim=-1) == pred_subword_ends).float() * \ (subword_starts[:, 0].squeeze(dim=-1) == pred_subword_starts).float() return { 'vloss': loss, 'vem': exact_match.mean(), 'qids': qids, 'answer_scores': answer_scores, 'f1': f1_scores, 'em': em_scores } def _get_question_end_index(self, input_ids): eos_token_indices = ( input_ids == self.tokenizer.eos_token_id).nonzero() assert eos_token_indices.ndim == 2 assert eos_token_indices.size(0) == 2 * input_ids.size(0) assert eos_token_indices.size(1) == 2 return eos_token_indices.view(input_ids.size(0), 2, 2)[:, 0, 1] def decode(self, input_ids, start_logits, end_logits): # find beginning of document question_end_index = self._get_question_end_index(input_ids) # bsz x seqlen => bsz x n_best_size start_logits_indices = start_logits.topk(k=self.args.n_best_size, dim=-1).indices end_logits_indices = end_logits.topk(k=self.args.n_best_size, dim=-1).indices answers = [] # This loop can't be vectorized, so loop over each example in the batch separately for i in range(start_logits_indices.size(0)): # bsz potential_answers = [] for start_logit_index in start_logits_indices[i]: # n_best_size for end_logit_index in end_logits_indices[i]: # n_best_size if start_logit_index <= question_end_index[i]: continue if end_logit_index <= question_end_index[i]: continue if start_logit_index > end_logit_index: continue answer_len = end_logit_index - start_logit_index + 1 if answer_len > self.args.max_answer_length: continue potential_answers.append({ 'start': start_logit_index, 'end': end_logit_index, 'start_logit': start_logits[i][start_logit_index].item(), 'end_logit': end_logits[i][end_logit_index].item() }) sorted_answers = sorted(potential_answers, key=lambda x: (x['start_logit'] +
x['end_logit']), reverse=True) if len(sorted_answers) == 0: answers.append({'text': 'NoAnswerFound', 'score': -1000000}) else: answer = sorted_answers[0] answer_token_ids = input_ids[i, answer['start']:answer['end'] + 1] answer_tokens = self.tokenizer.convert_ids_to_tokens( answer_token_ids.tolist()) text = self.tokenizer.convert_tokens_to_string(answer_tokens) score = answer['start_logit'] + answer['end_logit'] answers.append({'text': text, 'score': score}) return answers def sync_list_across_gpus(self, l, device, dtype): l_tensor = torch.tensor(l, device=device, dtype=dtype) gather_l_tensor = [ torch.ones_like(l_tensor) for _ in range(self.trainer.world_size) ] torch.distributed.all_gather(gather_l_tensor, l_tensor) return torch.cat(gather_l_tensor).tolist() def validation_end(self, outputs): avg_loss = torch.stack([x['vloss'] for x in outputs]).mean() avg_em = torch.stack([x['vem'] for x in outputs]).mean() string_qids = [item for sublist in outputs for item in sublist['qids']] int_qids = [ self.val_dataloader_object.dataset.val_qid_string_to_int_map[qid] for qid in string_qids ] answer_scores = [ item for sublist in outputs for item in sublist['answer_scores'] ] f1_scores = [item for sublist in outputs for item in sublist['f1']] em_scores = [item for sublist in outputs for item in sublist['em']] print( f'before sync --> sizes: {len(int_qids)}, {len(answer_scores)}, {len(f1_scores)}, {len(em_scores)}' ) if self.trainer.use_ddp: torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM) avg_loss /= self.trainer.world_size torch.distributed.all_reduce(avg_em, op=torch.distributed.ReduceOp.SUM) avg_em /= self.trainer.world_size int_qids = self.sync_list_across_gpus(int_qids, avg_loss.device, torch.int) answer_scores = self.sync_list_across_gpus(answer_scores, avg_loss.device, torch.float) f1_scores = self.sync_list_across_gpus(f1_scores, avg_loss.device, torch.float) em_scores = self.sync_list_across_gpus(em_scores, avg_loss.device, torch.int) print( f'after sync --> sizes: {len(int_qids)}, {len(answer_scores)}, {len(f1_scores)}, {len(em_scores)}' ) # Because of having multiple documents per questions, some questions might have multiple corresponding answers # Here, we only keep the answer with the highest answer_score qa_with_duplicates = defaultdict(list) for qid, answer_score, f1_score, em_score in zip( int_qids, answer_scores, f1_scores, em_scores): qa_with_duplicates[qid].append({ 'answer_score': answer_score, 'f1': f1_score, 'em': em_score }) f1_scores = [] em_scores = [] for qid, answer_metrics in qa_with_duplicates.items(): top_answer = sorted(answer_metrics, key=lambda x: x['answer_score'], reverse=True)[0] f1_scores.append(top_answer['f1']) em_scores.append(top_answer['em']) avg_val_f1 = sum(f1_scores) / len(f1_scores) avg_val_em = sum(em_scores) / len(em_scores) logs = { 'val_loss': avg_loss, 'val_em': avg_em, 'avg_val_f1': avg_val_f1, 'avg_val_em': avg_val_em } return {'avg_val_loss': avg_loss, 'log': logs, 'progress_bar': logs} def test_step(self, batch, batch_nb): input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends) loss, start_logits, end_logits = output[:3] answers = self.decode(input_ids, start_logits, end_logits) # each batch is one document answers = sorted(answers, key=lambda x: x['score'], reverse=True)[0:1] qids = [qids] assert len(answers) == len(qids) return {'qids': qids, 'answers': answers} def test_end(self, outputs): qids = 
[item for sublist in outputs for item in sublist['qids']] answers = [item for sublist in outputs for item in sublist['answers']] qa_with_duplicates = defaultdict(list) for qid, answer in zip(qids, answers): qa_with_duplicates[qid].append({ 'answer_score': answer['score'], 'answer_text': answer['text'], }) qid_to_answer_text = {} for qid, answer_metrics in qa_with_duplicates.items(): top_answer = sorted(answer_metrics, key=lambda x: x['answer_score'], reverse=True)[0] qid_to_answer_text[qid] = top_answer['answer_text'] with open('predictions.json', 'w') as f: json.dump(qid_to_answer_text, f) return {'count': len(qid_to_answer_text)} def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None): optimizer.step() optimizer.zero_grad() self.scheduler.step(self.global_step) def configure_optimizers(self): def lr_lambda(current_step): if current_step < self.args.warmup: return float(current_step) / float(max(1, self.args.warmup)) return max( 0.0, float(self.args.steps - current_step) / float(max(1, self.args.steps - self.args.warmup))) optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr) self.scheduler = LambdaLR( optimizer, lr_lambda, last_epoch=-1 ) # scheduler is not saved in the checkpoint, but global_step is, which is enough to restart self.scheduler.step(self.global_step) return optimizer @pl.data_loader def train_dataloader(self): if self.train_dataloader_object is not None: return self.train_dataloader_object dataset = TriviaQADataset( file_path=self.args.train_dataset, tokenizer=self.tokenizer, max_seq_len=self.args.max_seq_len, max_doc_len=self.args.max_doc_len, doc_stride=self.args.doc_stride, max_num_answers=self.args.max_num_answers, max_question_len=self.args.max_question_len, ignore_seq_with_no_answers=self.args.ignore_seq_with_no_answers) sampler = torch.utils.data.distributed.DistributedSampler( dataset) if self.trainer.use_ddp else None dl = DataLoader(dataset, batch_size=1, shuffle=(sampler is None), num_workers=self.args.num_workers, sampler=sampler, collate_fn=TriviaQADataset.collate_one_doc_and_lists) self.train_dataloader_object = dl return self.train_dataloader_object @pl.data_loader def val_dataloader(self): if self.val_dataloader_object is not None: return self.val_dataloader_object dataset = TriviaQADataset(file_path=self.args.dev_dataset, tokenizer=self.tokenizer, max_seq_len=self.args.max_seq_len, max_doc_len=self.args.max_doc_len, doc_stride=self.args.doc_stride, max_num_answers=self.args.max_num_answers, max_question_len=self.args.max_question_len, ignore_seq_with_no_answers=False ) # evaluation data should keep all examples sampler = torch.utils.data.distributed.DistributedSampler( dataset) if self.trainer.use_ddp else None dl = DataLoader(dataset, batch_size=1, shuffle=(sampler is None), num_workers=self.args.num_workers, sampler=sampler, collate_fn=TriviaQADataset.collate_one_doc_and_lists) self.val_dataloader_object = dl return self.val_dataloader_object @pl.data_loader def test_dataloader(self): if self.test_dataloader_object is not None: return self.test_dataloader_object dataset = TriviaQADataset(file_path=self.args.dev_dataset, tokenizer=self.tokenizer, max_seq_len=self.args.max_seq_len, max_doc_len=self.args.max_doc_len, doc_stride=self.args.doc_stride, max_num_answers=self.args.max_num_answers, max_question_len=self.args.max_question_len, ignore_seq_with_no_answers=False ) # evaluation data should keep all examples dl = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.args.num_workers, 
sampler=None, collate_fn=TriviaQADataset.collate_one_doc_and_lists) self.test_dataloader_object = dl return self.test_dataloader_object def configure_ddp(self, model, device_ids): model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True) return model @staticmethod def add_model_specific_args(parser, root_dir): parser.add_argument("--save_dir", type=str, default='triviaqa') parser.add_argument("--save_prefix", type=str, required=True) parser.add_argument("--train_dataset", type=str, required=False, help="Path to the training squad-format") parser.add_argument("--dev_dataset", type=str, required=True, help="Path to the dev squad-format") parser.add_argument("--batch_size", type=int, default=8, help="Batch size") parser.add_argument( "--gpus", type=str, default='0', help= "Comma separated list of gpus. Default is gpu 0. To use CPU, use --gpus " " ") parser.add_argument("--warmup", type=int, default=200, help="Number of warmup steps") parser.add_argument("--lr", type=float, default=0.0001, help="Maximum learning rate") parser.add_argument( "--val_every", type=float, default=0.2, help="Number of training steps between validations") parser.add_argument("--val_percent_check", default=1.00, type=float, help='Percent of validation data used') parser.add_argument("--num_workers", type=int, default=4, help="Number of data loader workers") parser.add_argument("--seed", type=int, default=1234, help="Seed") parser.add_argument("--epochs", type=int, default=30, help="Number of epochs") parser.add_argument( "--max_seq_len", type=int, default=4096, help="Maximum length of seq passed to the transformer model") parser.add_argument( "--max_doc_len", type=int, default=4096, help="Maximum number of wordpieces of the input document") parser.add_argument( "--max_num_answers", type=int, default=64, help="Maximum number of answer spans per document (64 => 94%)") parser.add_argument("--max_question_len", type=int, default=55, help="Maximum length of the question") parser.add_argument( "--doc_stride", type=int, default=-1, help= "Overlap between document chunks. Use -1 to only use the first chunk" ) parser.add_argument( "--ignore_seq_with_no_answers", action='store_true', help= "each example should have at least one answer. Default is False") parser.add_argument("--disable_checkpointing", action='store_true', help="No logging or checkpointing") parser.add_argument( "--n_best_size", type=int, default=20, help="Number of answer candidates. Used at decoding time") parser.add_argument( "--max_answer_length", type=int, default=30, help="maximum num of wordpieces/answer. Used at decoding time") parser.add_argument( "--regular_softmax_loss", action='store_true', help= "IF true, use regular softmax. Default is using ORed softmax loss") parser.add_argument("--test", action='store_true', help="Test only, no training") parser.add_argument("--model_path", type=str, required=True, help="Path to the checkpoint directory") parser.add_argument("--no_progress_bar", action='store_true', help="no progress bar. Good for printing") parser.add_argument( "--attention_mode", type=str, choices=['tvm', 'sliding_chunks'], default='sliding_chunks', help='Which implementation of selfattention to use') parser.add_argument( "--fp32", action='store_true', help="default is fp16. Use --fp32 to switch to fp32") return parser
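For reference, the closed-over lr_lambda from configure_optimizers, evaluated with the parser's default warmup of 200 and a hypothetical steps=10000:

warmup, steps = 200, 10000  # warmup matches the argparse default; steps is hypothetical

def lr_lambda(current_step):
    if current_step < warmup:
        return float(current_step) / float(max(1, warmup))
    return max(0.0, float(steps - current_step) / float(max(1, steps - warmup)))

print([round(lr_lambda(s), 3) for s in (0, 100, 200, 5000, 10000)])
# [0.0, 0.5, 1.0, 0.51, 0.0]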
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    jmmd_loss.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = jmmd_loss((f_s, F.softmax(y_s, dim=1)),
                                  (f_t, F.softmax(y_t, dim=1)))
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

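# Standalone sketch of the cat/chunk pattern used in the loop above: stacking
# the source and target mini-batches keeps their rows aligned, so
# chunk(2, dim=0) recovers the per-domain outputs. Shapes are toy assumptions.
import torch

x_s, x_t = torch.randn(4, 3), torch.randn(4, 3)
x = torch.cat((x_s, x_t), dim=0)  # [8, 3]: source rows first, then target rows
y_s, y_t = x.chunk(2, dim=0)      # split back into the two domains
assert torch.equal(y_s, x_s) and torch.equal(y_t, x_t)
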
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss,
          class_weight_module: AutomaticUpdateClassWeightModule,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.1f')
    non_partial_classes_weights = AverageMeter('Non-partial Weight', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs, tgt_accs,
         partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(
            y_s, labels_s,
            class_weight_module.get_class_weight_for_cross_entropy_loss())
        w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s)
        transfer_loss = domain_adv(f_s, f_t, w_s, w_t)
        class_weight_module.step()
        partial_classes_weight, non_partial_classes_weight = \
            class_weight_module.get_partial_classes_weight()
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        partial_classes_weights.update(partial_classes_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

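# Standalone sketch of the per-class weighting applied to F.cross_entropy
# above. In the real loop class_weight_module supplies the weights; the values
# below are made up. A zero weight effectively drops that class from the loss.
import torch
import torch.nn.functional as F

logits = torch.randn(5, 3)
labels = torch.tensor([0, 1, 2, 0, 1])
class_weight = torch.tensor([1.0, 0.5, 0.0])
loss = F.cross_entropy(logits, labels, weight=class_weight)
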
def train(model: torch.nn.Module,
          train_dls: List[DataLoader],
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: MultiDatasetClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          model_dir: str = "wandb_local",
          gradient_accumulation: int = 1,
          domain_name: str = ''):
    #best_loss = float('inf')
    best_f1 = 0.0
    patience_counter = 0
    epoch_counter = 0
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    try:
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                    if len(batches) == 0:
                        continue
                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        # Testing with random domains to see if any effect
                        #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
                        domains = batch[3]
                        loss, logits, alpha = model(input_ids, attention_mask=masks,
                                                    domains=domains, labels=labels,
                                                    ret_alpha=True)
                        loss = loss.mean() / gradient_accumulation
                        if i % log_interval == 0:
                            # wandb.log({
                            #     "Loss": loss.item(),
                            #     "alpha0": alpha[:, 0].cpu(),
                            #     "alpha1": alpha[:, 1].cpu(),
                            #     "alpha2": alpha[:, 2].cpu(),
                            #     "alpha_shared": alpha[:, 3].cpu()
                            # })
                            wandb.log({"Loss": loss.item()})
                        loss.backward()
                        i += 1
                        pbar.update(1)
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation F1: {F1}")

        # Save the best model and reset patience whenever validation F1 improves
        #if val_loss < best_loss:
        if F1 > best_f1:
            best_model = model.state_dict()
            #best_loss = val_loss
            best_f1 = F1
            #wandb.run.summary['best_validation_loss'] = best_loss
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss})
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1

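# Standalone sketch of why each loss above is divided by gradient_accumulation:
# backward() accumulates gradients, so scaling each micro-batch loss by 1/k
# makes the summed gradient equal the gradient of the mean loss. Toy numbers.
import torch

w = torch.zeros(1, requires_grad=True)
targets = [torch.tensor([1.0]), torch.tensor([3.0])]
k = len(targets)
for t in targets:
    loss = (w - t).pow(2).sum() / k  # scale each micro-batch loss by 1/k
    loss.backward()                  # gradients accumulate in w.grad
# d/dw mean_i (w - t_i)^2 at w = 0 is mean(-2 * t_i) = -4
assert torch.allclose(w.grad, torch.tensor([-4.0]))
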
def main():
    # Training settings and hyper-parameters
    parser = argparse.ArgumentParser(
        description='Data Source (Batch) Prediction for Cell Lines')

    # Dataset parameters ######################################################
    # Pre-processing for dataframes
    parser.add_argument('--rnaseq_scaling', type=str, default='std',
                        help='scaling method for RNA sequence',
                        choices=SCALING_METHODS)

    # Feature usage and partitioning settings
    parser.add_argument('--rnaseq_feature_usage', type=str, default='combat',
                        help='RNA sequence data used',
                        choices=['source_scale', 'combat', ])
    parser.add_argument('--validation_ratio', type=float, default=0.2,
                        help='ratio for validation dataset')

    # Network configuration ###################################################
    parser.add_argument('--layer_dim', type=int, default=256,
                        help='dimension of layers for RNA sequence')
    parser.add_argument('--num_layers', type=int, default=4,
                        help='number of layers for RNA sequence')

    # Training and validation parameters ######################################
    parser.add_argument('--opt', type=str, default='SGD',
                        help='optimizer for data source prediction',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--lr', type=float, default=1e-2,
                        help='learning rate for data source prediction')

    # Starting epoch for validation
    parser.add_argument('--val_start_epoch', type=int, default=0,
                        help='starting epoch for data source prediction')

    # Early stopping based on data source prediction accuracy
    parser.add_argument('--early_stop_patience', type=int, default=50,
                        help='patience for early stopping based on data '
                             'source prediction accuracy')

    # Global/shared training parameters
    parser.add_argument('--l2_regularization', type=float, default=0.,
                        help='L2 regularization for nn weights')
    parser.add_argument('--lr_decay_factor', type=float, default=0.98,
                        help='decay factor for learning rate')
    parser.add_argument('--trn_batch_size', type=int, default=32,
                        help='input batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=256,
                        help='input batch size for validation')
    parser.add_argument('--max_num_batches', type=int, default=10000,
                        help='maximum number of batches per epoch')
    parser.add_argument('--max_num_epochs', type=int, default=1000,
                        help='maximum number of epochs')

    # Miscellaneous settings ##################################################
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--rand_state', type=int, default=0,
                        help='random state of numpy/sklearn/pytorch')

    args = parser.parse_args()
    print('Training Arguments:\n' + json.dumps(vars(args), indent=4))

    # Setting up random seed for reproducible and deterministic results
    seed_random_state(args.rand_state)

    # Computation device config (cuda or cpu)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Data loaders for training/validation ####################################
    dataloader_kwargs = {
        'timeout': 1,
        'shuffle': True,  # was the string 'True'; shuffle expects a bool
        # 'num_workers': multiprocessing.cpu_count() if use_cuda else 0,
        'num_workers': NUM_WORKER if use_cuda else 0,
        'pin_memory': True if use_cuda else False, }

    # Drug response dataloaders for training/validation
    cl_clf_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,
        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,
        'rnaseq_scaling': args.rnaseq_scaling,
        'predict_target': 'source',
        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'validation_ratio': args.validation_ratio, }

    cl_clf_trn_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=True, **cl_clf_dataset_kwargs),
        batch_size=args.trn_batch_size, **dataloader_kwargs)
    cl_clf_val_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=False, **cl_clf_dataset_kwargs),
        batch_size=args.val_batch_size, **dataloader_kwargs)

    # Constructing and initializing neural networks ###########################
    net = nn.Sequential()
    prev_dim = cl_clf_trn_loader.dataset.rnaseq_dim
    for label in ['site', 'type', 'category']:
        prev_dim += len(get_label_dict(DATA_ROOT, '%s_dict.txt' % label))

    for i in range(args.num_layers):
        # net.add_module('residual_block_%d' % i,
        #                ResBlock(layer_dim=args.layer_dim,
        #                         num_layers=2,
        #                         dropout=0.))
        net.add_module('dense_%d' % i, nn.Linear(prev_dim, args.layer_dim))
        net.add_module('dropout_%d' % i, nn.Dropout(0.2))
        prev_dim = args.layer_dim
        net.add_module('relu_%d' % i, nn.ReLU())

    num_data_src = len(get_label_dict(DATA_ROOT, 'data_src_dict.txt'))
    net.add_module('dense', nn.Linear(args.layer_dim, num_data_src))
    net.add_module('logsoftmax', nn.LogSoftmax(dim=1))

    net.apply(basic_weight_init)
    net.to(device)
    print(net)

    # Optimizers, learning rate decay, and miscellaneous ######################
    opt = get_optimizer(opt_type=args.opt, networks=net,
                        learning_rate=args.lr,
                        l2_regularization=args.l2_regularization)
    lr_decay = LambdaLR(optimizer=opt,
                        lr_lambda=lambda e: args.lr_decay_factor ** e)

    # Training/validation loops ###############################################
    val_acc = []
    best_acc = 0.
    patience = 0
    start_time = time.time()

    for epoch in range(args.max_num_epochs):
        print('=' * 80 + '\nTraining Epoch %3i:' % (epoch + 1))
        epoch_start_time = time.time()

        # Training loop #######################################################
        net.train()
        for batch_idx, (rnaseq, data_src, cl_site, cl_type, cl_category) \
                in enumerate(cl_clf_trn_loader):
            if batch_idx >= args.max_num_batches:
                break
            rnaseq, data_src, cl_site, cl_type, cl_category = \
                rnaseq.to(device), data_src.to(device), cl_site.to(device), \
                cl_type.to(device), cl_category.to(device)
            net.zero_grad()
            out_data_src = net(
                torch.cat((rnaseq, cl_site, cl_type, cl_category), dim=1))
            F.nll_loss(input=out_data_src, target=data_src).backward()
            opt.step()

        # Step the scheduler once per epoch; this replaces the deprecated
        # lr_decay.step(epoch) call at the top of the loop
        lr_decay.step()

        # Validation loop #####################################################
        net.eval()
        correct_data_src = 0
        with torch.no_grad():
            for rnaseq, data_src, cl_site, cl_type, cl_category \
                    in cl_clf_val_loader:
                rnaseq, data_src, cl_site, cl_type, cl_category = \
                    rnaseq.to(device), data_src.to(device), \
                    cl_site.to(device), cl_type.to(device), \
                    cl_category.to(device)
                out_data_src = net(
                    torch.cat((rnaseq, cl_site, cl_type, cl_category), dim=1))
                pred_data_src = out_data_src.max(1, keepdim=True)[1]
                correct_data_src += pred_data_src.eq(
                    data_src.view_as(pred_data_src)).sum().item()

        data_src_acc = 100. * correct_data_src / len(cl_clf_val_loader.dataset)
        print('\tCell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; '
              % data_src_acc)

        # Results recording and early stopping
        val_acc.append(data_src_acc)
        if data_src_acc > best_acc:
            patience = 0
            best_acc = data_src_acc
        else:
            patience += 1
            if patience >= args.early_stop_patience:
                print('Validation accuracy does not improve for %d epochs ... '
                      'invoking early stopping.' % patience)
                break

        print('Epoch Running Time: %.1f Seconds.'
              % (time.time() - epoch_start_time))

    print('Program Running Time: %.1f Seconds.' % (time.time() - start_time))
    print('Best Cell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; '
          % np.amax(val_acc))

    import matplotlib.pyplot as plt
    x = range(1, len(val_acc) + 1)
    plt.plot(x, val_acc)
    plt.xlabel('Epochs')
    plt.ylabel('Validation Accuracy')
    plt.title('Validation Accuracy over Training')
    plt.show()

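# Standalone sketch of the LambdaLR schedule above: the base lr is multiplied
# by lr_decay_factor ** epoch, so with the default factor of 0.98 the lr falls
# to roughly 82% of its starting value after 10 epochs. Pure arithmetic.
factor = 0.98
for epoch in (0, 1, 10, 100):
    print(epoch, factor ** epoch)  # 1.0, 0.98, ~0.817, ~0.133
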
def train(model, state, path, annotations, val_path, val_annotations,
          resize, max_size, jitter, batch_size, iterations, val_iterations,
          mixed_precision, lr, warmup, milestones, gamma, is_master=True,
          world=1, use_dali=True, verbose=True, metrics_url=None, logdir=None):
    'Train the model on the given dataset'

    # Prepare model
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=0.0001, momentum=0.9)

    model, optimizer = amp.initialize(
        model, optimizer,
        opt_level='O2' if mixed_precision else 'O0',
        keep_batchnorm_fp32=True,
        loss_scale=128.0,
        verbosity=is_master)

    if world > 1:
        model = DistributedDataParallel(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])
    scheduler = LambdaLR(optimizer, schedule)

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path, jitter, max_size, batch_size, stride,
        world, annotations, training=True)
    if verbose:
        print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available()
            else 'gpu' if world == 1 else 'gpus'))
        print('     batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(logdir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):

            # Forward pass
            profiler.start('fw')
            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            scheduler.step(iteration)

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))
            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('Train/Loss/Focal', focal_loss, iteration)
                    writer.add_scalar('Train/Loss/Box', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate})

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations
                                    or iteration % val_iterations == 0):
                infer(model, val_path, None, resize, max_size, batch_size,
                      annotations=val_annotations, mixed_precision=mixed_precision,
                      is_master=is_master, world=world, use_dali=use_dali,
                      is_validation=True, verbose=False, logdir=logdir,
                      iteration=iteration)
                model.train()

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()

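# Standalone sketch of the `schedule` closure defined in train() above:
# a linear warmup from 10% to 100% of the base lr, then step decay by `gamma`
# at each milestone. The warmup/milestone values are arbitrary assumptions.
def schedule(train_iter, warmup=1000, milestones=(20000, 30000), gamma=0.1):
    if warmup and train_iter <= warmup:
        return 0.9 * train_iter / warmup + 0.1
    return gamma ** len([m for m in milestones if m <= train_iter])

assert schedule(0) == 0.1                # start at 10% of the base lr
assert abs(schedule(1000) - 1.0) < 1e-9  # full lr once warmup finishes
assert schedule(25000) == 0.1            # decayed once past the first milestone
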
class Learner(object):
    def __init__(self, model, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
                 lr=7e-4, alpha=0.99, epsilon=1e-5, number_updates=int(1e6),
                 lrschedule='linear', use_actor_critic=False, rew_loss_coef=0.0,
                 st_loss_coef=0.0, subtree_loss_coef=0.0, nsteps=5, nenvs=1,
                 tree_depth=0):
        self.max_grad_norm = max_grad_norm
        self.use_actor_critic = use_actor_critic
        self.use_reward_loss = model.predict_rewards and rew_loss_coef > 0
        self.rew_loss_coef = rew_loss_coef
        self.use_st_loss = st_loss_coef > 0 and tree_depth > 0
        self.st_loss_coef = st_loss_coef
        self.subtree_loss_coef = subtree_loss_coef
        self.use_subtree_loss = subtree_loss_coef > 0
        self.model = model
        self.nsteps = nsteps
        self.nenvs = nenvs
        self.batch_size = nsteps * nenvs
        self.num_actions = model.num_actions
        self.tree_depth = tree_depth

        if USE_CUDA:
            self.model = self.model.cuda()

        if not self.use_actor_critic:
            self.target_model = copy.deepcopy(self.model)
            if USE_CUDA:
                self.target_model = self.target_model.cuda()

        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=lr,
                                             alpha=alpha, eps=epsilon)

        if lrschedule == "linear":
            self.scheduler = LambdaLR(self.optimizer,
                                      lambda step: 1.0 - (step / number_updates))
        elif lrschedule == "constant":
            self.scheduler = LambdaLR(self.optimizer, lambda step: 1.0)
        else:
            raise ValueError("lrschedule should be 'linear' or 'constant'")

        self.step = self.model.step
        if self.use_actor_critic:
            self.value = self.model.value
            self.ent_coef = ent_coef
            self.vf_coef = vf_coef
        else:
            self.value = self.target_model.value

    def train(self, obs, next_obs, returns, rewards, masks, actions, values):
        """
        :param obs: [batch_size x height x width x channels] observations in NHWC
        :param next_obs: [batch_size x height x width x channels] one-step next states
        :param returns: [batch_size] n-step discounted returns with bootstrapped value
        :param rewards: [batch_size] 1-step rewards
        :param masks: [batch_size] boolean episode termination mask
        :param actions: [batch_size] actions taken
        :param values: [batch_size] predicted state values
        """
        # compute the sequences we need to get back reward predictions
        action_sequences, reward_sequences, sequence_mask = build_sequences(
            [torch.from_numpy(actions), torch.from_numpy(rewards)],
            self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
        action_sequences = cudify(action_sequences.long().squeeze(-1))
        reward_sequences = make_variable(reward_sequences.squeeze(-1))
        sequence_mask = make_variable(sequence_mask.squeeze(-1))

        Q, V, tree_result = self.model(obs)

        actions = make_variable(torch.from_numpy(actions).long(), requires_grad=False)
        returns = make_variable(torch.from_numpy(returns), requires_grad=False)

        policy_loss, value_loss, reward_loss, state_loss, subtree_loss_np, policy_entropy = \
            0, 0, 0, 0, 0, 0
        if self.use_actor_critic:
            values = make_variable(torch.from_numpy(values), requires_grad=False)
            advantages = returns - values

            probs = F.softmax(Q, dim=-1)
            log_probs = F.log_softmax(Q, dim=-1)
            log_probs_taken = log_probs.gather(1, actions.unsqueeze(1)).squeeze()

            pg_loss = -torch.mean(log_probs_taken * advantages.squeeze())
            vf_loss = F.mse_loss(V, returns)
            entropy = -torch.mean(torch.sum(probs * log_probs, 1))

            loss = pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy

            policy_loss = pg_loss.data.cpu().numpy()
            value_loss = vf_loss.data.cpu().numpy()
            policy_entropy = entropy.data.cpu().numpy()
        else:
            Q_taken = Q.gather(1, actions.unsqueeze(1)).squeeze()
            bellman_loss = F.mse_loss(Q_taken, returns)
            loss = bellman_loss
            value_loss = bellman_loss.data.cpu().numpy()

        if self.use_reward_loss:
            r_taken = get_paths(tree_result["rewards"], action_sequences,
                                self.batch_size, self.num_actions)
            rew_loss = F.mse_loss(torch.cat(r_taken, 1), reward_sequences,
                                  reduction='none')  # was reduce=False (deprecated)
            rew_loss = torch.sum(rew_loss * sequence_mask) / sequence_mask.sum()
            loss = loss + rew_loss * self.rew_loss_coef
            reward_loss = rew_loss.data.cpu().numpy()

        if self.use_st_loss:
            st_embeddings = tree_result["embeddings"][0]
            st_targets, st_mask = build_sequences([st_embeddings.data],
                                                  self.nenvs, self.nsteps,
                                                  self.tree_depth,
                                                  return_mask=True, offset=1)
            st_targets = make_variable(st_targets.view(self.batch_size, -1))
            st_mask = make_variable(st_mask.view(self.batch_size, -1))
            st_taken = get_paths(tree_result["embeddings"][1:], action_sequences,
                                 self.batch_size, self.num_actions)
            st_taken_cat = torch.cat(st_taken, 1)
            st_loss = F.mse_loss(st_taken_cat, st_targets, reduction='none')
            st_loss = torch.sum(st_loss * st_mask) / st_mask.sum()
            state_loss = st_loss.data.cpu().numpy()
            loss = loss + st_loss * self.st_loss_coef

        if self.use_subtree_loss:
            subtree_taken = get_subtree(tree_result["values"], action_sequences,
                                        self.batch_size, self.num_actions)
            target_subtrees = tree_result["values"][0:-1]
            subtree_taken_clip = time_shift_tree(subtree_taken,
                                                 self.nenvs, self.nsteps, -1)
            target_subtrees_clip = time_shift_tree(target_subtrees,
                                                   self.nenvs, self.nsteps, 1)
            subtree_loss = [(s_taken - s_target).pow(2).mean()
                            for (s_taken, s_target)
                            in zip(subtree_taken_clip, target_subtrees_clip)]
            subtree_loss = sum(subtree_loss)
            subtree_loss_np = subtree_loss.data.cpu().numpy()
            loss = loss + subtree_loss * self.subtree_loss_coef

        self.scheduler.step()
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(
            self.model.parameters(), self.max_grad_norm)  # was clip_grad_norm (deprecated)
        self.optimizer.step()

        return policy_loss, value_loss, reward_loss, state_loss, \
            subtree_loss_np, policy_entropy, grad_norm

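# Standalone sketch of the masked-mean reduction used for rew_loss and st_loss
# above: element-wise losses are zeroed at padded timesteps and averaged over
# the valid ones only. The tensors here are toy values.
import torch
import torch.nn.functional as F

pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.0, 0.0], [3.0, 0.0]])
mask = torch.tensor([[1.0, 0.0], [1.0, 0.0]])  # second step of each row is padding
per_elem = F.mse_loss(pred, target, reduction='none')
masked_mean = torch.sum(per_elem * mask) / mask.sum()
assert masked_mean.item() == 0.0  # valid positions match exactly
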