import shutil

import torch
import torch.nn as nn
from torch.optim import RMSprop
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter
from tqdm import tqdm

# AverageTracker is a project-local running-average meter (update() / .avg),
# assumed importable from the surrounding repo.


class Train:
    def __init__(self, model, trainloader, valloader, args):
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader
        self.args = args
        self.start_epoch = 0
        self.best_top1 = 0.0

        # Loss function and optimizer
        self.loss = None
        self.optimizer = None
        self.create_optimization()

        # Model loading
        self.load_pretrained_model()
        self.load_checkpoint(self.args.resume_from)

        # Tensorboard writer
        self.summary_writer = SummaryWriter(log_dir=args.summary_dir)

    def train(self):
        for cur_epoch in range(self.start_epoch, self.args.num_epochs):
            # Initialize tqdm
            tqdm_batch = tqdm(self.trainloader,
                              desc="Epoch-" + str(cur_epoch) + "-")

            # Learning rate adjustment
            self.adjust_learning_rate(self.optimizer, cur_epoch)

            # Meters for tracking the average values
            loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

            # Set the model to be in training mode (for dropout and batchnorm)
            self.model.train()

            for data, target in tqdm_batch:
                if self.args.cuda:
                    data, target = data.cuda(), target.cuda()

                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Optimization step
                self.optimizer.zero_grad()
                cur_loss.backward()
                self.optimizer.step()

                # Top-1 and top-5 accuracy calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.data, target,
                                                           topk=(1, 5))
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())

            # Summary writing
            self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg, cur_epoch)

            # Print in console
            tqdm_batch.close()
            print("Epoch-" + str(cur_epoch) + " | " +
                  "loss: " + str(loss.avg) +
                  " - acc-top1: " + str(top1.avg)[:7] +
                  " - acc-top5: " + str(top5.avg)[:7])

            # Evaluate on the validation set
            if cur_epoch % self.args.test_every == 0 and self.valloader:
                self.test(self.valloader, cur_epoch)

            # Checkpointing
            is_best = top1.avg > self.best_top1
            self.best_top1 = max(top1.avg, self.best_top1)
            self.save_checkpoint({
                'epoch': cur_epoch + 1,
                'state_dict': self.model.state_dict(),
                'best_top1': self.best_top1,
                'optimizer': self.optimizer.state_dict(),
            }, is_best)

    def test(self, testloader, cur_epoch=-1):
        loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

        # Set the model to be in testing mode (for dropout and batchnorm)
        self.model.eval()

        with torch.no_grad():
            for data, target in testloader:
                if self.args.cuda:
                    data, target = data.cuda(), target.cuda()

                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Top-1 and top-5 accuracy calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.data, target,
                                                           topk=(1, 5))
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())

        if cur_epoch != -1:
            # Summary writing
            self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-1-acc", top1.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-5-acc", top5.avg, cur_epoch)

        print("Test Results" + " | " +
              "loss: " + str(loss.avg) +
              " - acc-top1: " + str(top1.avg)[:7] +
              " - acc-top5: " + str(top5.avg)[:7])

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
        torch.save(state, self.args.checkpoint_dir + filename)
        if is_best:
            shutil.copyfile(self.args.checkpoint_dir + filename,
                            self.args.checkpoint_dir + 'model_best.pth.tar')

    def compute_accuracy(self, output, target, topk=(1,)):
        """Computes the accuracy@k for the specified values of k."""
        maxk = max(topk)
        batch_size = target.size(0)

        _, idx = output.topk(maxk, 1, True, True)
        idx = idx.t()
        correct = idx.eq(target.view(1, -1).expand_as(idx))

        acc_arr = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            acc_arr.append(correct_k.mul_(1.0 / batch_size))
        return acc_arr

    def adjust_learning_rate(self, optimizer, epoch):
        """Decays the initial LR by args.learning_rate_decay every epoch."""
        learning_rate = self.args.learning_rate * (
            self.args.learning_rate_decay ** epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    def create_optimization(self):
        self.loss = nn.CrossEntropyLoss()
        if self.args.cuda:
            self.loss.cuda()
        self.optimizer = RMSprop(self.model.parameters(),
                                 self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)

    def load_pretrained_model(self):
        try:
            print("Loading ImageNet pretrained weights...")
            pretrained_dict = torch.load(self.args.pretrained_path)
            self.model.load_state_dict(pretrained_dict)
            print("ImageNet pretrained weights loaded successfully.\n")
        except (FileNotFoundError, RuntimeError):
            print("No ImageNet pretrained weights exist. Skipping...\n")

    def load_checkpoint(self, filename):
        filename = self.args.checkpoint_dir + filename
        try:
            print("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)
            self.start_epoch = checkpoint['epoch']
            self.best_top1 = checkpoint['best_top1']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("Checkpoint loaded successfully from '{}' at (epoch {})\n"
                  .format(self.args.checkpoint_dir, checkpoint['epoch']))
        except (FileNotFoundError, KeyError, RuntimeError):
            print("No checkpoint exists from '{}'. Skipping...\n".format(
                self.args.checkpoint_dir))
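# A minimal, hypothetical driver for the Train class above. The attribute
# names mirror what the class reads from `args`; the values are illustrative
# only, not the repo's defaults. `model`, `trainloader` and `valloader` come
# from the surrounding project.
from argparse import Namespace

args = Namespace(cuda=torch.cuda.is_available(),
                 num_epochs=90, test_every=1,
                 learning_rate=0.045, learning_rate_decay=0.98,
                 momentum=0.9, weight_decay=4e-5,
                 resume_from='checkpoint.pth.tar',
                 pretrained_path='pretrained.pth',
                 checkpoint_dir='experiments/checkpoints/',
                 summary_dir='experiments/summaries/')

trainer = Train(model, trainloader, valloader, args)
trainer.train()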
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from torch.optim import Optimizer
from torch.optim.rmsprop import RMSprop

# ModelConfigBase, ModelExecutionMode, DeviceAwareModule, DataParallelModel,
# OptimizerType, create_model_with_temperature_scaling and
# summary_for_segmentation_models are project-local and assumed importable.


class ModelAndInfo:
    """
    This class contains the model and optional associated information, as well as methods to create
    models and optimizers, move these to GPU and load state from checkpoints.
    Attributes are:
      config: the model configuration information
      model: the model created based on the config
      optimizer: the optimizer created based on the config and associated with the model
      checkpoint_path: the path to load the checkpoint from, can be None
      mean_teacher_model: the mean teacher model, if and as specified by the config
      is_model_adjusted: whether model adjustments (which cannot be done twice) have been applied to the model
      is_mean_teacher_model_adjusted: whether model adjustments (which cannot be done twice)
      have been applied to the mean teacher model
      checkpoint_epoch: the training epoch this model was created in, if loaded from disk
      model_execution_mode: mode this model will be run in
    """

    MODEL_STATE_DICT_KEY = 'state_dict'
    OPTIMIZER_STATE_DICT_KEY = 'opt_dict'
    MEAN_TEACHER_STATE_DICT_KEY = 'mean_teacher_state_dict'
    EPOCH_KEY = 'epoch'

    def __init__(self,
                 config: ModelConfigBase,
                 model_execution_mode: ModelExecutionMode,
                 checkpoint_path: Optional[Path] = None):
        """
        :param config: the model configuration information
        :param model_execution_mode: mode this model will be run in
        :param checkpoint_path: the path to load the checkpoint from, can be None
        """
        self.config = config
        self.checkpoint_path = checkpoint_path
        self.model_execution_mode = model_execution_mode

        self._model = None
        self._mean_teacher_model = None
        self._optimizer = None
        self.checkpoint_epoch = None
        self.is_model_adjusted = False
        self.is_mean_teacher_model_adjusted = False

    @property
    def model(self) -> DeviceAwareModule:
        if not self._model:
            raise ValueError("Model has not been created.")
        return self._model

    @property
    def optimizer(self) -> Optimizer:
        if not self._optimizer:
            raise ValueError("Optimizer has not been created.")
        return self._optimizer

    @property
    def mean_teacher_model(self) -> Optional[DeviceAwareModule]:
        if not self._mean_teacher_model and self.config.compute_mean_teacher_model:
            raise ValueError("Mean teacher model has not been created.")
        return self._mean_teacher_model

    @classmethod
    def _load_checkpoint(cls, model: DeviceAwareModule, checkpoint_path: Path,
                         key_in_state_dict: str, use_gpu: bool) -> int:
        """
        Loads a checkpoint of a model, which may be the model or the mean teacher model.
        Assumes the model has already been created, and the checkpoint exists. This does not set the
        checkpoint epoch. This method should not be called externally. Use instead
        try_load_checkpoint_for_model or try_load_checkpoint_for_mean_teacher_model.
        :param model: model to load weights into
        :param key_in_state_dict: the key for the model weights in the checkpoint state dict
        :return: checkpoint epoch from the state dict
        """
        logging.info(f"Loading checkpoint {checkpoint_path}")
        # For model debugging, allow loading a GPU-trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if use_gpu else 'cpu'
        checkpoint = torch.load(str(checkpoint_path), map_location=map_location)

        if isinstance(model, torch.nn.DataParallel):
            model.module.load_state_dict(checkpoint[key_in_state_dict])
        else:
            model.load_state_dict(checkpoint[key_in_state_dict])

        return checkpoint[ModelAndInfo.EPOCH_KEY]

    @classmethod
    def _adjust_for_gpus(cls, model: DeviceAwareModule, config: ModelConfigBase,
                         model_execution_mode: ModelExecutionMode) -> DeviceAwareModule:
        """
        Updates a torch model so that input mini-batches are parallelized across the batch dimension to utilise
        multiple GPUs. If model parallel is set to True and execution is in test mode, then the model is
        partitioned to perform full-volume inference.
        This assumes the model has been created, that the optimizer has not yet been created, and that
        the model has not already been adjusted. This method should not be called externally.
        Use instead adjust_model_for_gpus or adjust_mean_teacher_model_for_gpus.
        :return: the adjusted model
        """
        if config.use_gpu:
            model = model.cuda()
            logging.info("Adjusting the model to use mixed precision training.")
            # If model parallel is set to True, then partition the network across all available GPUs.
            if config.use_model_parallel:
                devices = config.get_cuda_devices()
                assert devices is not None  # for mypy
                model.partition_model(devices=devices)  # type: ignore
        else:
            logging.info("Making no adjustments to the model because no GPU was found.")

        # Update model-related config attributes (after model parallel is activated).
        config.adjust_after_mixed_precision_and_parallel(model)

        # DataParallel enables running the model with multiple GPUs by splitting samples across GPUs.
        # If the model is used in training mode, data parallel is activated by default.
        # Similarly, if model parallel is not activated, data parallel is used as a backup option.
        use_data_parallel = (model_execution_mode == ModelExecutionMode.TRAIN) or (not config.use_model_parallel)
        if config.use_gpu and use_data_parallel:
            logging.info("Adjusting the model to use DataParallel")
            # Move all layers to the default GPU before activating data parallel.
            # This needs to happen even though we put the model to the GPU at the beginning of the method,
            # but we may have spread it across multiple GPUs later.
            model = model.cuda()
            model = DataParallelModel(model, device_ids=config.get_cuda_devices())

        return model

    def create_model(self) -> None:
        """
        Creates a model (with temperature scaling) according to the config given.
        """
        self._model = create_model_with_temperature_scaling(self.config)

    def try_load_checkpoint_for_model(self) -> bool:
        """
        Loads a checkpoint of a model. The provided model checkpoint must match the stored model.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        if self._model is None:
            raise ValueError("Model must be created before a checkpoint can be loaded.")

        if not self.checkpoint_path:
            raise ValueError("No checkpoint provided")

        if not self.checkpoint_path.is_file():
            logging.warning(f"No checkpoint found at {self.checkpoint_path}, "
                            f"current working dir {os.getcwd()}")
            return False

        epoch = ModelAndInfo._load_checkpoint(model=self._model,
                                              checkpoint_path=self.checkpoint_path,
                                              key_in_state_dict=ModelAndInfo.MODEL_STATE_DICT_KEY,
                                              use_gpu=self.config.use_gpu)

        logging.info(f"Loaded model from checkpoint (epoch: {epoch})")
        self.checkpoint_epoch = epoch
        return True

    def adjust_model_for_gpus(self) -> None:
        """
        Updates the torch model so that input mini-batches are parallelized across the batch dimension to utilise
        multiple GPUs. If model parallel is set to True and execution is in test mode, then the model is
        partitioned to perform full-volume inference.
        """
        if self._model is None:
            raise ValueError("Model must be created before it can be adjusted.")

        # Adjusting twice causes an error.
        if self.is_model_adjusted:
            logging.debug("model_and_info.is_model_adjusted is already True")

        if self._optimizer:
            raise ValueError("Create an optimizer only after creating and adjusting the model.")

        self._model = ModelAndInfo._adjust_for_gpus(model=self._model,
                                                    config=self.config,
                                                    model_execution_mode=self.model_execution_mode)

        self.is_model_adjusted = True
        logging.debug("model_and_info.is_model_adjusted set to True")

    def create_summary_and_adjust_model_for_gpus(self) -> None:
        """
        Generates the model summary, which is required for model partitioning across GPUs, and then moves
        the model to GPU with data parallel/model parallel by calling adjust_model_for_gpus.
        """
        if self._model is None:
            raise ValueError("Model must be created before it can be adjusted.")

        if self.config.is_segmentation_model:
            summary_for_segmentation_models(self.config, self._model)
        # Prepare for mixed precision training and data parallelization (no-op if already done).
        # This relies on the information generated in the model summary.
        self.adjust_model_for_gpus()

    def try_create_model_and_load_from_checkpoint(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        Also updates the checkpoint_epoch.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        self.create_model()
        if self.checkpoint_path:
            # Load the stored model. If there is no checkpoint present, return immediately.
            return self.try_load_checkpoint_for_model()
        return True

    def try_create_model_load_from_checkpoint_and_adjust(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        The model is then adjusted for data parallelism and mixed precision.
        Also updates the checkpoint_epoch.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        success = self.try_create_model_and_load_from_checkpoint()
        self.create_summary_and_adjust_model_for_gpus()
        return success

    def create_mean_teacher_model(self) -> None:
        """
        Creates a mean teacher model (with temperature scaling) according to the config given.
        """
        self._mean_teacher_model = create_model_with_temperature_scaling(self.config)

    def try_load_checkpoint_for_mean_teacher_model(self) -> bool:
        """
        Loads a checkpoint of the mean teacher model. The provided checkpoint must match the stored model.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        if self._mean_teacher_model is None:
            raise ValueError("Mean teacher model must be created before a checkpoint can be loaded.")

        if not self.checkpoint_path:
            raise ValueError("No checkpoint provided")

        if not self.checkpoint_path.is_file():
            logging.warning(f"No checkpoint found at {self.checkpoint_path}, "
                            f"current working dir {os.getcwd()}")
            return False

        epoch = ModelAndInfo._load_checkpoint(model=self._mean_teacher_model,
                                              checkpoint_path=self.checkpoint_path,
                                              key_in_state_dict=ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY,
                                              use_gpu=self.config.use_gpu)

        logging.info(f"Loaded mean teacher model from checkpoint (epoch: {epoch})")
        self.checkpoint_epoch = epoch
        return True

    def adjust_mean_teacher_model_for_gpus(self) -> None:
        """
        Updates the mean teacher model so that input mini-batches are parallelized across the batch
        dimension to utilise multiple GPUs. If model parallel is set to True and execution is in test mode,
        then the model is partitioned to perform full-volume inference.
        """
        if self._mean_teacher_model is None:
            raise ValueError("Mean teacher model must be created before it can be adjusted.")

        # Adjusting twice causes an error.
        if self.is_mean_teacher_model_adjusted:
            logging.debug("model_and_info.is_mean_teacher_model_adjusted is already True")

        self._mean_teacher_model = ModelAndInfo._adjust_for_gpus(
            model=self._mean_teacher_model,
            config=self.config,
            model_execution_mode=self.model_execution_mode)

        self.is_mean_teacher_model_adjusted = True
        logging.debug("model_and_info.is_mean_teacher_model_adjusted set to True")

    def create_summary_and_adjust_mean_teacher_model_for_gpus(self) -> None:
        """
        Generates the model summary, which is required for model partitioning across GPUs, and then moves
        the model to GPU with data parallel/model parallel by calling adjust_mean_teacher_model_for_gpus.
        """
        if self._mean_teacher_model is None:
            raise ValueError("Mean teacher model must be created before it can be adjusted.")

        if self.config.is_segmentation_model:
            summary_for_segmentation_models(self.config, self._mean_teacher_model)
        # Prepare for mixed precision training and data parallelization (no-op if already done).
        # This relies on the information generated in the model summary.
        self.adjust_mean_teacher_model_for_gpus()

    def try_create_mean_teacher_model_and_load_from_checkpoint(self) -> bool:
        """
        Creates a mean teacher model as per the config, and loads the parameters from the given
        checkpoint path. Also updates the checkpoint_epoch.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        self.create_mean_teacher_model()
        if self.checkpoint_path:
            # Load the stored model. If there is no checkpoint present, return immediately.
            return self.try_load_checkpoint_for_mean_teacher_model()
        return True

    def try_create_mean_teacher_model_load_from_checkpoint_and_adjust(self) -> bool:
        """
        Creates a mean teacher model as per the config, and loads the parameters from the given
        checkpoint path. The model is then adjusted for data parallelism and mixed precision.
        Also updates the checkpoint_epoch.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        success = self.try_create_mean_teacher_model_and_load_from_checkpoint()
        self.create_summary_and_adjust_mean_teacher_model_for_gpus()
        return success

    def create_optimizer(self) -> None:
        """
        Creates a torch optimizer for the given model, and stores it as an instance variable in the
        current object.
        """
        # Make sure the model is created before we create the optimizer.
        if self._model is None:
            raise ValueError("Model must be created before the optimizer can be created.")

        # Select the optimizer type.
        if self.config.optimizer_type in [OptimizerType.Adam, OptimizerType.AMSGrad]:
            self._optimizer = torch.optim.Adam(self._model.parameters(),
                                               self.config.l_rate,
                                               self.config.adam_betas,
                                               self.config.opt_eps,
                                               self.config.weight_decay,
                                               amsgrad=self.config.optimizer_type == OptimizerType.AMSGrad)
        elif self.config.optimizer_type == OptimizerType.SGD:
            self._optimizer = torch.optim.SGD(self._model.parameters(),
                                              self.config.l_rate,
                                              self.config.momentum,
                                              weight_decay=self.config.weight_decay)
        elif self.config.optimizer_type == OptimizerType.RMSprop:
            self._optimizer = RMSprop(self._model.parameters(),
                                      self.config.l_rate,
                                      self.config.rms_alpha,
                                      self.config.opt_eps,
                                      self.config.weight_decay,
                                      self.config.momentum)
        else:
            raise NotImplementedError(
                f"Optimizer type {self.config.optimizer_type.value} is not implemented")

    def try_load_checkpoint_for_optimizer(self) -> bool:
        """
        Loads a checkpoint of an optimizer.
        :return: True if the checkpoint exists and the optimizer state was loaded, False otherwise.
        """
        if self._optimizer is None:
            raise ValueError("Optimizer must be created before an optimizer checkpoint can be loaded.")

        if not self.checkpoint_path:
            logging.warning("No checkpoint path provided.")
            return False

        if not self.checkpoint_path.is_file():
            logging.warning(f"No checkpoint found at {self.checkpoint_path}, "
                            f"current working dir {os.getcwd()}")
            return False

        logging.info(f"Loading checkpoint {self.checkpoint_path}")
        # For model debugging, allow loading a GPU-trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if self.config.use_gpu else 'cpu'
        checkpoint = torch.load(str(self.checkpoint_path), map_location=map_location)

        if self._optimizer:
            self._optimizer.load_state_dict(checkpoint[ModelAndInfo.OPTIMIZER_STATE_DICT_KEY])

        logging.info(f"Loaded optimizer from checkpoint (epoch: {checkpoint[ModelAndInfo.EPOCH_KEY]})")
        self.checkpoint_epoch = checkpoint[ModelAndInfo.EPOCH_KEY]
        return True

    def try_create_optimizer_and_load_from_checkpoint(self) -> bool:
        """
        Creates an optimizer and loads its state from a checkpoint.
        :return: True if the checkpoint exists and the optimizer state was loaded, False otherwise.
        """
        self.create_optimizer()
        if self.checkpoint_path:
            return self.try_load_checkpoint_for_optimizer()
        return True

    def save_checkpoint(self, epoch: int) -> None:
        """
        Saves a checkpoint of the current model and optimizer parameters in the specified folder
        and uploads it to the output blob storage of the current run context.
        The checkpoint's name for epoch 123 would be 123_checkpoint.pth.tar.
        :param epoch: The last epoch used to train the model.
        """
        logging.getLogger().disabled = True

        model_state_dict = self.model.module.state_dict() \
            if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
        checkpoint_file_path = self.config.get_path_to_checkpoint(epoch)
        info_to_store = {
            ModelAndInfo.EPOCH_KEY: epoch,
            ModelAndInfo.MODEL_STATE_DICT_KEY: model_state_dict,
            ModelAndInfo.OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict()
        }
        if self.config.compute_mean_teacher_model:
            assert self.mean_teacher_model is not None  # for mypy, getter has this built in
            mean_teacher_model_state_dict = self.mean_teacher_model.module.state_dict() \
                if isinstance(self.mean_teacher_model, torch.nn.DataParallel) \
                else self.mean_teacher_model.state_dict()
            info_to_store[ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY] = mean_teacher_model_state_dict

        torch.save(info_to_store, checkpoint_file_path)
        logging.getLogger().disabled = False
        logging.info(f"Saved model checkpoint for epoch {epoch} to {checkpoint_file_path}")
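# A hypothetical call order for ModelAndInfo, assuming `config` is a
# ModelConfigBase instance from the surrounding codebase; the epoch value is
# illustrative. create_optimizer must only run after the model is adjusted,
# which the adjust-and-load helper below guarantees.
model_and_info = ModelAndInfo(config=config,
                              model_execution_mode=ModelExecutionMode.TRAIN,
                              checkpoint_path=config.get_path_to_checkpoint(epoch=10))
model_loaded = model_and_info.try_create_model_load_from_checkpoint_and_adjust()
optimizer_loaded = model_and_info.try_create_optimizer_and_load_from_checkpoint()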
import os

import torch
from torch.nn import DataParallel
from torch.optim import RMSprop
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

# hg1/hg2/hg8, Mpii, Logger, adjust_learning_rate, do_training_epoch,
# do_validation_epoch and save_checkpoint are project-local helpers from the
# surrounding stacked-hourglass repo and are assumed importable.


def main(args):
    # Select the hardware device to use for inference.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    torch.set_grad_enabled(False)

    # Create the checkpoint dir.
    os.makedirs(args.checkpoint, exist_ok=True)

    if args.arch == 'hg1':
        model = hg1(pretrained=False)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False)
    else:
        raise Exception('unrecognised model architecture: ' + args.arch)

    model = DataParallel(model).to(device)

    optimizer = RMSprop(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

    best_acc = 0

    # Optionally resume from a checkpoint.
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # Create the data loaders.
    train_dataset = Mpii(args.image_path, is_train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    val_dataset = Mpii(args.image_path, is_train=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    # Train and evaluate.
    lr = args.lr
    for epoch in trange(args.start_epoch, args.epochs, desc='Overall', ascii=True):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)

        # Train for one epoch.
        train_loss, train_acc = do_training_epoch(train_loader, model, device,
                                                  Mpii.DATA_INFO, optimizer,
                                                  acc_joints=Mpii.ACC_JOINTS)

        # Evaluate on the validation set.
        valid_loss, valid_acc, predictions = do_validation_epoch(val_loader, model, device,
                                                                 Mpii.DATA_INFO, False,
                                                                 acc_joints=Mpii.ACC_JOINTS)

        # Print metrics.
        tqdm.write(f'[{epoch + 1:3d}/{args.epochs:3d}] lr={lr:0.2e} '
                   f'train_loss={train_loss:0.4f} train_acc={100 * train_acc:0.2f} '
                   f'valid_loss={valid_loss:0.4f} valid_acc={100 * valid_acc:0.2f}')

        # Append to the logger file.
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
        logger.plot_to_file(os.path.join(args.checkpoint, 'log.svg'), ['Train Acc', 'Val Acc'])

        # Remember the best accuracy and save a checkpoint.
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint, snapshot=args.snapshot)

    logger.close()
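# A hypothetical command-line entry point for main() above. The flag names
# mirror the attributes main() reads from `args`; the defaults are
# illustrative assumptions, not the repo's actual settings.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Stacked-hourglass training (sketch)')
    parser.add_argument('--arch', default='hg1', choices=['hg1', 'hg2', 'hg8'])
    parser.add_argument('--image-path', dest='image_path', required=True)
    parser.add_argument('--checkpoint', default='checkpoint')
    parser.add_argument('--resume', default='')
    parser.add_argument('--snapshot', default=0, type=int)
    parser.add_argument('--start-epoch', dest='start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--train-batch', dest='train_batch', default=6, type=int)
    parser.add_argument('--test-batch', dest='test_batch', default=6, type=int)
    parser.add_argument('--workers', default=4, type=int)
    parser.add_argument('--lr', default=2.5e-4, type=float)
    parser.add_argument('--momentum', default=0.0, type=float)
    parser.add_argument('--weight-decay', dest='weight_decay', default=0.0, type=float)
    parser.add_argument('--schedule', nargs='+', default=[60, 90], type=int)
    parser.add_argument('--gamma', default=0.1, type=float)
    main(parser.parse_args())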
import os
import time

import torch
import torch.nn as nn
import torchvision.utils as vutils
from torch.optim import RMSprop
from torch.utils.data import DataLoader

try:
    import wandb
except ImportError:
    wandb = None

# get_data, get_model, get_loss, weights_init_d, weights_init_g and
# save_checkpoint are project-local helpers and assumed importable.


def train_sim(epoch_num=10, optim_type='ACGD',
              startPoint=None, start_n=0,
              z_dim=128, batchsize=64,
              l2_penalty=0.0, momentum=0.0,
              log=False, loss_name='WGAN',
              model_name='dc', model_config=None,
              data_path='None',
              show_iter=100, logdir='test',
              dataname='CIFAR10',
              device='cpu', gpu_num=1):
    lr_d = 1e-4
    lr_g = 1e-4
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset, batch_size=batchsize,
                            shuffle=True, num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    optim_d = RMSprop(D.parameters(), lr=lr_d)
    optim_g = RMSprop(G.parameters(), lr=lr_g)

    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optim_d.load_state_dict(chk['d_optim'])
        optim_g.load_state_dict(chk['g_optim'])
        print('Start from %s' % startPoint)

    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))

    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)

    for e in range(epoch_num):
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name, g_loss=False,
                            d_real=d_real, d_fake=d_fake,
                            l2_weight=l2_penalty, D=D)

            # Simultaneous update: both players step on the same loss.
            D.zero_grad()
            G.zero_grad()
            loss.backward()
            optim_d.step()
            optim_g.step()

            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , Loss: %.5f, time: %.3fs'
                      % (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                    save_checkpoint(path=logdir,
                                    name='%s-%s%.3f_%d.pth'
                                         % (optim_type, model_name, lr_g, count + start_n),
                                    D=D, G=G,
                                    optimizer=optim_d, g_optimizer=optim_g)

            if wandb and log:
                wandb.log({
                    'Real score': d_real.mean().item(),
                    'Fake score': d_fake.mean().item(),
                    'Loss': loss.item()
                })
            count += 1
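# A hypothetical invocation of train_sim. The data path and log directory are
# illustrative; model_name and dataname follow the values the function itself
# branches on ('DCGAN', 'CIFAR10').
train_sim(epoch_num=10, model_name='DCGAN', loss_name='WGAN',
          dataname='CIFAR10', data_path='datas/cifar10',
          batchsize=64, show_iter=500, logdir='sim-test',
          device='cuda:0' if torch.cuda.is_available() else 'cpu')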
import os
import random
from math import exp

import matplotlib.pyplot as plt
import numpy as np
import numpy.random as rand
import torch
from torch.optim import RMSprop

# TrainingAgent, ReplayMemory and get_filename are project-local and
# assumed importable.


class DQNAgent(TrainingAgent):
    def __init__(self, input_shape, action_space, seed, device, model,
                 gamma, alpha, tau, batch_size, update, replay, buffer_size,
                 env, decay=200, path='model', num_epochs=0, max_step=50000,
                 learn_interval=20):
        '''Initialise a DQNAgent object.

        buffer_size   : size of the replay buffer to sample from
        gamma         : discount rate
        alpha         : learn rate
        replay        : buffer fill level after which learning starts
        update        : interval, in learn steps, between target-network updates
        learn_interval: timestep interval between learning steps
        '''
        super(DQNAgent, self).__init__(input_shape, action_space, seed, device,
                                       model, gamma, alpha, tau, batch_size,
                                       max_step, env, num_epochs, path)
        self.buffer_size = buffer_size
        self.update = update
        self.replay = replay
        self.interval = learn_interval

        # Q-Networks
        self.policy_net = self.model(input_shape, action_space).to(self.device)
        self.target_net = self.model(input_shape, action_space).to(self.device)
        self.optimiser = RMSprop(self.policy_net.parameters(), lr=self.alpha)

        # Replay memory
        self.memory = ReplayMemory(self.buffer_size, self.batch_size,
                                   self.seed, self.device)

        # Timesteps
        self.t_step = 0
        self.l_step = 0

        # Exponentially decaying epsilon schedule.
        self.EPSILON_START = 1.0
        self.EPSILON_FINAL = 0.02
        self.EPS_DECAY = decay
        self.epsilon_delta = lambda frame_idx: self.EPSILON_FINAL + \
            (self.EPSILON_START - self.EPSILON_FINAL) * exp(-1. * frame_idx / self.EPS_DECAY)

    def step(self, state, action, reward, next_state, done):
        '''One step of environment interaction, learning when due.'''
        # Save the experience into the replay buffer.
        self.memory.add(state, action, reward, next_state, done)

        # Learn every `learn_interval` timesteps.
        self.t_step = (self.t_step + 1) % self.interval
        if self.t_step == 0:
            # If there are enough samples in the memory, get a random subset and learn.
            if len(self.memory) > self.replay:
                experience = self.memory.sample()
                print('learning')
                self.learn(experience)

    def action(self, state, eps=0.):
        '''Returns an action for the given state as per the current policy.'''
        # Unpack the state.
        state = torch.from_numpy(state).unsqueeze(0).to(self.device)
        if rand.rand() > eps:
            # Epsilon-greedy action selection.
            action_val = self.policy_net(state)
            return np.argmax(action_val.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_space))

    def learn(self, exp):
        state, action, reward, next_state, done = exp

        # Get expected Q values from the policy model.
        Q_expt_current = self.policy_net(state)
        Q_expt = Q_expt_current.gather(1, action.unsqueeze(1)).squeeze(1)

        # Get max predicted Q values for the next states from the target model.
        Q_target_next = self.target_net(next_state).detach().max(1)[0]

        # Compute Q targets for the current states.
        Q_target = reward + (self.gamma * Q_target_next * (1 - done))

        # Compute the loss.
        loss = torch.nn.functional.mse_loss(Q_expt, Q_target)

        # Minimize the loss.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        # Soft-update the target network every `update` learn steps.
        self.l_step = (self.l_step + 1) % self.update
        if self.l_step == 0:
            self.soft_update(self.policy_net, self.target_net, self.tau)

    def model_dict(self) -> dict:
        '''State to save for the models.'''
        return {'policy_net': self.policy_net.state_dict(),
                'target_net': self.target_net.state_dict(),
                'optimizer': self.optimiser.state_dict(),
                'num_epoch': self.num_epochs,
                'scores': self.scores}

    def load_model(self, state_dict, eval=True):
        '''Load parameters and model information from prior training for
        continuation of training.'''
        self.policy_net.load_state_dict(state_dict['policy_net'])
        self.target_net.load_state_dict(state_dict['target_net'])
        self.optimiser.load_state_dict(state_dict['optimizer'])
        self.scores = state_dict['scores']
        if eval:
            self.policy_net.eval()
            self.target_net.eval()
        else:
            self.policy_net.train()
            self.target_net.train()
        self.num_epochs = state_dict['num_epoch']

    # θ' = θ×τ + θ'×(1−τ)
    def soft_update(self, policy_model, target_model, tau):
        for t_param, p_param in zip(target_model.parameters(),
                                    policy_model.parameters()):
            t_param.data.copy_(tau * p_param.data + (1.0 - tau) * t_param.data)

    def train(self, n_episodes=1000, render=False):
        """n_episodes: maximum number of training episodes.
        Saves the model every 100 episodes.
        """
        filename = get_filename()
        self.env.render(render)  # Toggles the render on
        for i_episode in range(n_episodes):
            self.num_epochs += 1
            state = self.stack_frames(None, self.reset(), True)
            score = 0
            eps = self.epsilon_delta(self.num_epochs)

            while True:
                action = self.action(state, eps)
                next_state, reward, done, info = self.env.step(action)
                score += reward
                next_state = self.stack_frames(state, next_state, False)
                self.step(state, action, reward, next_state, done)
                state = next_state
                if done:
                    break

            self.scores.append(score)  # Save the most recent score

            # Every 100 episodes, save the model and plot the scores.
            if i_episode % 100 == 0:
                self.save_obj(self.model_dict(),
                              os.path.join(self.path, filename))
                print("Creating plot")
                plt.figure()
                plt.plot(np.arange(len(self.scores)), self.scores)
                plt.xlabel('Episode #')
                plt.ylabel('Score')
                plt.savefig(f'{i_episode} plot.png')
                print("Plot saved")

        # Return the scores.
        return self.scores
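# A hypothetical wiring of the agent. `QNetwork` and `env` are placeholders
# for whatever Q-network class and (Atari-style, frame-stacked) environment
# wrapper the surrounding repo provides; the hyperparameters are illustrative.
agent = DQNAgent(input_shape=(4, 84, 84), action_space=6, seed=0,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 model=QNetwork, gamma=0.99, alpha=1e-4, tau=1e-3,
                 batch_size=32, update=4, replay=10_000,
                 buffer_size=100_000, env=env, decay=200)
scores = agent.train(n_episodes=1000)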
import logging
import os
from pathlib import Path
from typing import Any, List, Optional

import torch
from torch.optim import Optimizer
from torch.optim.rmsprop import RMSprop

# ModelConfigBase, ModelExecutionMode, DeviceAwareModule, DataParallelModel,
# OptimizerType, create_model_with_temperature_scaling and
# summary_for_segmentation_models are project-local and assumed importable.


class ModelAndInfo:
    """
    This class contains the model and optional associated information, as well as methods to create
    models and optimizers, move these to GPU and load state from checkpoints.
    Attributes are:
      config: the model configuration information
      model: the model created based on the config
      optimizer: the optimizer created based on the config and associated with the model
      checkpoint_path: the path to load the checkpoint from, can be None
      is_mean_teacher: whether this is (intended to be) a mean teacher model
      is_adjusted: whether model adjustments (which cannot be done twice) have been applied
      checkpoint_epoch: the training epoch this model was created in, if loaded from disk
      model_execution_mode: mode this model will be run in
    """

    def __init__(self,
                 config: ModelConfigBase,
                 model_execution_mode: ModelExecutionMode,
                 is_mean_teacher: bool = False,
                 checkpoint_path: Optional[Path] = None):
        """
        :param config: the model configuration information
        :param model_execution_mode: mode this model will be run in
        :param is_mean_teacher: whether this is (intended to be) a mean teacher model
        :param checkpoint_path: the path to load the checkpoint from, can be None
        """
        self.config = config
        self.is_mean_teacher = is_mean_teacher
        self.checkpoint_path = checkpoint_path
        self.model_execution_mode = model_execution_mode

        self._model = None
        self._optimizer = None
        self.checkpoint_epoch = None
        self.is_adjusted = False

    @property
    def model(self) -> DeviceAwareModule:
        if not self._model:
            raise ValueError("Model has not been created.")
        return self._model

    @property
    def optimizer(self) -> Optimizer:
        if not self._optimizer:
            raise ValueError("Optimizer has not been created.")
        return self._optimizer

    def to_cuda(self) -> None:
        """
        Moves the model to GPU.
        """
        if self._model is None:
            raise ValueError("Model must be created before it can be moved to GPU.")
        self._model = self._model.cuda()

    def set_data_parallel(self, device_ids: Optional[List[Any]]) -> None:
        if self._model is None:
            raise ValueError("Model must be created before it can be moved to DataParallel.")
        self._model = DataParallelModel(self._model, device_ids=device_ids)

    def create_model(self) -> None:
        """
        Creates a model (with temperature scaling) according to the config given.
        """
        self._model = create_model_with_temperature_scaling(self.config)

    def try_load_checkpoint_for_model(self) -> bool:
        """
        Loads a checkpoint of a model. The provided model checkpoint must match the stored model.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        if self._model is None:
            raise ValueError("Model must be created before a checkpoint can be loaded.")

        if not self.checkpoint_path:
            raise ValueError("No checkpoint provided")

        if not self.checkpoint_path.is_file():
            logging.warning(f"No checkpoint found at {self.checkpoint_path}, "
                            f"current working dir {os.getcwd()}")
            return False

        logging.info(f"Loading checkpoint {self.checkpoint_path}")
        # For model debugging, allow loading a GPU-trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if self.config.use_gpu else 'cpu'
        checkpoint = torch.load(str(self.checkpoint_path), map_location=map_location)

        if isinstance(self._model, torch.nn.DataParallel):
            self._model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self._model.load_state_dict(checkpoint['state_dict'])

        logging.info(f"Loaded model from checkpoint (epoch: {checkpoint['epoch']})")
        self.checkpoint_epoch = checkpoint['epoch']
        return True

    def adjust_model_for_gpus(self) -> None:
        """
        Updates the torch model so that input mini-batches are parallelized across the batch dimension to utilise
        multiple GPUs. If model parallel is set to True and execution is in test mode, then the model is
        partitioned to perform full-volume inference.
        """
        if self._model is None:
            raise ValueError("Model must be created before it can be adjusted.")

        # Adjusting twice causes an error.
        if self.is_adjusted:
            logging.debug("model_and_info.is_adjusted is already True")

        if self._optimizer:
            raise ValueError("Create an optimizer only after creating and adjusting the model.")

        if self.config.use_gpu:
            self.to_cuda()
            logging.info("Adjusting the model to use mixed precision training.")
            # If model parallel is set to True, then partition the network across all available GPUs.
            if self.config.use_model_parallel:
                devices = self.config.get_cuda_devices()
                assert devices is not None  # for mypy
                self._model.partition_model(devices=devices)  # type: ignore
        else:
            logging.info("Making no adjustments to the model because no GPU was found.")

        # Update model-related config attributes (after model parallel is activated).
        self.config.adjust_after_mixed_precision_and_parallel(self._model)

        # DataParallel enables running the model with multiple GPUs by splitting samples across GPUs.
        # If the model is used in training mode, data parallel is activated by default.
        # Similarly, if model parallel is not activated, data parallel is used as a backup option.
        use_data_parallel = (self.model_execution_mode == ModelExecutionMode.TRAIN) or (
            not self.config.use_model_parallel)
        if self.config.use_gpu and use_data_parallel:
            logging.info("Adjusting the model to use DataParallel")
            # Move all layers to the default GPU before activating data parallel.
            # This needs to happen even though we put the model to the GPU at the beginning of the method,
            # but we may have spread it across multiple GPUs later.
            self.to_cuda()
            self.set_data_parallel(device_ids=self.config.get_cuda_devices())

        self.is_adjusted = True
        logging.debug("model_and_info.is_adjusted set to True")

    def create_summary_and_adjust_model_for_gpus(self) -> None:
        """
        Generates the model summary, which is required for model partitioning across GPUs, and then moves
        the model to GPU with data parallel/model parallel by calling adjust_model_for_gpus.
        """
        if self._model is None:
            raise ValueError("Model must be created before it can be adjusted.")

        if self.config.is_segmentation_model:
            summary_for_segmentation_models(self.config, self._model)
        # Prepare for mixed precision training and data parallelization (no-op if already done).
        # This relies on the information generated in the model summary.
        self.adjust_model_for_gpus()

    def try_create_model_and_load_from_checkpoint(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        Also updates the checkpoint_epoch.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        self.create_model()
        # for mypy
        assert self._model
        if self.checkpoint_path:
            # Load the stored model. If there is no checkpoint present, return immediately.
            return self.try_load_checkpoint_for_model()
        return True

    def try_create_model_load_from_checkpoint_and_adjust(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        The model is then adjusted for data parallelism and mixed precision, running in TEST mode.
        Also updates the checkpoint_epoch.
        :return: True if the checkpoint exists and was loaded, False otherwise.
        """
        success = self.try_create_model_and_load_from_checkpoint()
        self.create_summary_and_adjust_model_for_gpus()
        return success

    def create_optimizer(self) -> None:
        """
        Creates a torch optimizer for the given model, and stores it as an instance variable in the
        current object.
        """
        # Make sure the model is created before we create the optimizer.
        if self._model is None:
            raise ValueError("Model must be created before the optimizer can be created.")

        # Select the optimizer type.
        if self.config.optimizer_type in [OptimizerType.Adam, OptimizerType.AMSGrad]:
            self._optimizer = torch.optim.Adam(self._model.parameters(),
                                               self.config.l_rate,
                                               self.config.adam_betas,
                                               self.config.opt_eps,
                                               self.config.weight_decay,
                                               amsgrad=self.config.optimizer_type == OptimizerType.AMSGrad)
        elif self.config.optimizer_type == OptimizerType.SGD:
            self._optimizer = torch.optim.SGD(self._model.parameters(),
                                              self.config.l_rate,
                                              self.config.momentum,
                                              weight_decay=self.config.weight_decay)
        elif self.config.optimizer_type == OptimizerType.RMSprop:
            self._optimizer = RMSprop(self._model.parameters(),
                                      self.config.l_rate,
                                      self.config.rms_alpha,
                                      self.config.opt_eps,
                                      self.config.weight_decay,
                                      self.config.momentum)
        else:
            raise NotImplementedError(
                f"Optimizer type {self.config.optimizer_type.value} is not implemented")

    def try_load_checkpoint_for_optimizer(self) -> bool:
        """
        Loads a checkpoint of an optimizer.
        :return: True if the checkpoint exists and the optimizer state was loaded, False otherwise.
        """
        if self._optimizer is None:
            raise ValueError("Optimizer must be created before an optimizer checkpoint can be loaded.")

        if not self.checkpoint_path:
            logging.warning("No checkpoint path provided.")
            return False

        if not self.checkpoint_path.is_file():
            logging.warning(f"No checkpoint found at {self.checkpoint_path}, "
                            f"current working dir {os.getcwd()}")
            return False

        logging.info(f"Loading checkpoint {self.checkpoint_path}")
        # For model debugging, allow loading a GPU-trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if self.config.use_gpu else 'cpu'
        checkpoint = torch.load(str(self.checkpoint_path), map_location=map_location)

        if self._optimizer:
            self._optimizer.load_state_dict(checkpoint['opt_dict'])

        logging.info(f"Loaded optimizer from checkpoint (epoch: {checkpoint['epoch']})")
        self.checkpoint_epoch = checkpoint['epoch']
        return True

    def try_create_optimizer_and_load_from_checkpoint(self) -> bool:
        """
        Creates an optimizer and loads its state from a checkpoint.
        :return: True if the checkpoint exists and the optimizer state was loaded, False otherwise.
        """
        self.create_optimizer()
        if self.checkpoint_path:
            return self.try_load_checkpoint_for_optimizer()
        return True
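# The checkpoint layout this class reads, reconstructed from the keys used
# above ('epoch', 'state_dict', 'opt_dict'); a minimal sketch, assuming
# `model`, `optimizer`, `epoch` and `checkpoint_path` exist in scope:
checkpoint = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'opt_dict': optimizer.state_dict(),
}
torch.save(checkpoint, str(checkpoint_path))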
# (Excerpt from an ALBERT fine-tuning script: the lines above the `else:`
#  run inside the training branch of a train/test switch whose opening lines
#  are not part of this snippet.)
        loss = lossCri(predictTensor, batchLabels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        if trainingTimes % display_step == 0:
            print("#################")
            print("Predict tensor is ", predictTensor)
            print("Labels are ", batchLabels)
            print("Learning rate is ", optimizer.state_dict()['param_groups'][0]["lr"])
            print("Loss is ", loss)
            print("Training time is ", trainingTimes)
            learning_rate = scheduler.calculateLearningRate()
            state_dic = optimizer.state_dict()
            state_dic["param_groups"][0]["lr"] = float(learning_rate)
            optimizer.load_state_dict(state_dic)

        trainingTimes += 1
        if trainingTimes % save_model_steps == 0:
            torch.save(model.state_dict(),
                       weight_save_path + "ALBERT_" + str(trainingTimes) + ".pth")
else:
    model.eval()
    model.load_state_dict(torch.load(weight_save_path + "ALBERT_" +
                                     str(testModelSelect) + ".pth"))
    predictLabels = []
    truthLabels = []
    print("POSITIVE SAMPLES PREDICT.")
    k = 0
    for sample in test_positive_samples:
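# The state_dict round-trip above overrides the learning rate by serialising
# and reloading the optimizer. PyTorch exposes the same thing directly via
# optimizer.param_groups; an equivalent sketch, keeping the snippet's
# project-local scheduler.calculateLearningRate() call:
for param_group in optimizer.param_groups:
    param_group["lr"] = float(scheduler.calculateLearningRate())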
import copy

import torch as th
from torch.optim import RMSprop

# OffPGCritic, QMixer, EpisodeBatch, build_td_lambda_targets, build_target_q
# and get_parameters_num are project-local and assumed importable.


class OffPGLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mac = mac
        self.logger = logger

        self.last_target_update_step = 0
        self.critic_training_steps = 0
        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = OffPGCritic(scheme, args)
        self.mixer = QMixer(args)
        self.target_critic = copy.deepcopy(self.critic)
        self.target_mixer = copy.deepcopy(self.mixer)

        self.agent_params = list(mac.parameters())
        self.critic_params = list(self.critic.parameters())
        self.mixer_params = list(self.mixer.parameters())
        self.params = self.agent_params + self.critic_params
        self.c_params = self.critic_params + self.mixer_params

        self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr)
        self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.lr)
        self.mixer_optimiser = RMSprop(params=self.mixer_params, lr=args.lr)

        print('Mixer Size: ')
        print(get_parameters_num(list(self.c_params)))

    def train(self, batch: EpisodeBatch, t_env: int, log):
        # Get the relevant quantities.
        bs = batch.batch_size
        max_t = batch.max_seq_length
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        avail_actions = batch["avail_actions"][:, :-1]
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        mask = mask.repeat(1, 1, self.n_agents).view(-1)
        states = batch["state"][:, :-1]

        # Build the Q-values.
        inputs = self.critic._build_inputs(batch, bs, max_t)
        q_vals = self.critic.forward(inputs).detach()[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection).
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        # Calculate the baseline.
        q_taken = th.gather(q_vals, dim=3, index=actions).squeeze(3)
        pi = mac_out.view(-1, self.n_actions)
        baseline = th.sum(mac_out * q_vals, dim=-1).view(-1).detach()

        # Calculate the policy gradient with the mask.
        pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = th.log(pi_taken)
        coe = self.mixer.k(states).view(-1)
        advantages = q_taken.view(-1) - baseline
        # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        coma_loss = -((coe * advantages.detach() * log_pi_taken) * mask).sum() / mask.sum()

        # dist_entropy = Categorical(pi).entropy().view(-1)
        # dist_entropy[mask == 0] = 0  # fill nan
        # entropy_loss = (dist_entropy * mask).sum() / mask.sum()
        # loss = coma_loss - self.args.ent_coef * entropy_loss / entropy_loss.item()
        loss = coma_loss

        # Optimise the agents.
        self.agent_optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip)
        self.agent_optimiser.step()

        # Compute the parameter sum for debugging.
        p_sum = 0.
        for p in self.agent_params:
            p_sum += p.data.abs().sum().item() / 100.0

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(log["critic_loss"])
            for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean",
                        "target_mean", "q_max_mean", "q_min_mean", "q_max_var", "q_min_var"]:
                self.logger.log_stat(key, sum(log[key]) / ts_logged, t_env)
            self.logger.log_stat("q_max_first", log["q_max_first"], t_env)
            self.logger.log_stat("q_min_first", log["q_min_first"], t_env)
            # self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env)
            # self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env)
            self.logger.log_stat("coma_loss", coma_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.logger.log_stat("pi_max",
                                 (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(),
                                 t_env)
            self.log_stats_t = t_env

    def train_critic(self, on_batch, best_batch=None, log=None):
        bs = on_batch.batch_size
        max_t = on_batch.max_seq_length
        rewards = on_batch["reward"][:, :-1]
        actions = on_batch["actions"][:, :]
        terminated = on_batch["terminated"][:, :-1].float()
        mask = on_batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = on_batch["avail_actions"][:]
        states = on_batch["state"]

        # Build the target Q-values.
        target_inputs = self.target_critic._build_inputs(on_batch, bs, max_t)
        target_q_vals = self.target_critic.forward(target_inputs).detach()
        targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)
        target_q = build_td_lambda_targets(rewards, terminated, mask, targets_taken,
                                           self.n_agents, self.args.gamma,
                                           self.args.td_lambda).detach()

        inputs = self.critic._build_inputs(on_batch, bs, max_t)

        if best_batch is not None:
            best_target_q, best_inputs, best_mask, best_actions, best_mac_out = \
                self.train_critic_best(best_batch)
            log["best_reward"] = th.mean(best_batch["reward"][:, :-1].squeeze(2).sum(-1), dim=0)
            target_q = th.cat((target_q, best_target_q), dim=0)
            inputs = th.cat((inputs, best_inputs), dim=0)
            mask = th.cat((mask, best_mask), dim=0)
            actions = th.cat((actions, best_actions), dim=0)
            states = th.cat((states, best_batch["state"]), dim=0)

        # Train the critic, one timestep at a time.
        for t in range(max_t - 1):
            mask_t = mask[:, t:t + 1]
            if mask_t.sum() < 0.5:
                continue
            q_vals = self.critic.forward(inputs[:, t:t + 1])
            q_ori = q_vals
            q_vals = th.gather(q_vals, 3, index=actions[:, t:t + 1]).squeeze(3)
            q_vals = self.mixer.forward(q_vals, states[:, t:t + 1])
            target_q_t = target_q[:, t:t + 1].detach()
            q_err = (q_vals - target_q_t) * mask_t
            critic_loss = (q_err ** 2).sum() / mask_t.sum()

            self.critic_optimiser.zero_grad()
            self.mixer_optimiser.zero_grad()
            critic_loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(self.c_params, self.args.grad_norm_clip)
            self.critic_optimiser.step()
            self.mixer_optimiser.step()
            self.critic_training_steps += 1

            log["critic_loss"].append(critic_loss.item())
            log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            log["td_error_abs"].append(q_err.abs().sum().item() / mask_elems)
            log["target_mean"].append((target_q_t * mask_t).sum().item() / mask_elems)
            log["q_taken_mean"].append((q_vals * mask_t).sum().item() / mask_elems)
            log["q_max_mean"].append(
                (th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_mean"].append(
                (th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_max_var"].append(
                (th.var(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_var"].append(
                (th.var(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            if t == 0:
                log["q_max_first"] = \
                    (th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems
                log["q_min_first"] = \
                    (th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems

        # Update the target network.
        if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps

    def train_critic_best(self, batch):
        bs = batch.batch_size
        max_t = batch.max_seq_length
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:]
        states = batch["state"]

        with th.no_grad():
            # Policy probabilities for all actions of the episode.
            mac_out = []
            self.mac.init_hidden(bs)
            for i in range(max_t):
                agent_outs = self.mac.forward(batch, t=i)
                mac_out.append(agent_outs)
            mac_out = th.stack(mac_out, dim=1).detach()

            # Mask out unavailable actions, renormalise (as in action selection).
            mac_out[avail_actions == 0] = 0
            mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
            mac_out[avail_actions == 0] = 0

            critic_mac = th.gather(mac_out, 3, actions).squeeze(3).prod(dim=2, keepdim=True)

            # Taken target Q-values.
            target_inputs = self.target_critic._build_inputs(batch, bs, max_t)
            target_q_vals = self.target_critic.forward(target_inputs).detach()
            targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)

            # Expected Q-values.
            exp_q = self.build_exp_q(target_q_vals, mac_out, states).detach()

            # TD error.
            targets_taken[:, -1] = targets_taken[:, -1] * (1 - th.sum(terminated, dim=1))
            exp_q[:, -1] = exp_q[:, -1] * (1 - th.sum(terminated, dim=1))
            targets_taken[:, :-1] = targets_taken[:, :-1] * mask
            exp_q[:, :-1] = exp_q[:, :-1] * mask
            td_q = (rewards + self.args.gamma * exp_q[:, 1:] - targets_taken[:, :-1]) * mask

            # Compute the target.
            target_q = build_target_q(td_q, targets_taken[:, :-1], critic_mac, mask,
                                      self.args.gamma, self.args.tb_lambda,
                                      self.args.step).detach()

            inputs = self.critic._build_inputs(batch, bs, max_t)

        return target_q, inputs, mask, actions, mac_out

    def build_exp_q(self, target_q_vals, mac_out, states):
        target_exp_q_vals = th.sum(target_q_vals * mac_out, dim=3)
        target_exp_q_vals = self.target_mixer.forward(target_exp_q_vals, states)
        return target_exp_q_vals

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.mixer.cuda()
        self.target_critic.cuda()
        self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path))
        th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path))
        th.save(self.mixer_optimiser.state_dict(), "{}/mixer_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(th.load("{}/critic.th".format(path),
                                            map_location=lambda storage, loc: storage))
        self.mixer.load_state_dict(th.load("{}/mixer.th".format(path),
                                           map_location=lambda storage, loc: storage))
        # Not quite right, but I don't want to save the target networks.
        # self.target_critic.load_state_dict(self.critic.agent.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path),
                                                     map_location=lambda storage, loc: storage))
        self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path),
                                                      map_location=lambda storage, loc: storage))
        self.mixer_optimiser.load_state_dict(th.load("{}/mixer_opt.th".format(path),
                                                     map_location=lambda storage, loc: storage))
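# The baseline in OffPGLearner.train pairs each agent's taken-action value
# with the policy-weighted expectation over its actions. A self-contained
# sketch with hypothetical sizes (bs=2 episodes, T=3 steps, n_agents=2,
# n_actions=4) illustrating the tensor shapes involved:
import torch as th

q_vals = th.randn(2, 3, 2, 4)                    # per-agent action values
pi = th.softmax(th.randn(2, 3, 2, 4), dim=-1)    # per-agent policy
baseline = (pi * q_vals).sum(dim=-1)             # E_{a~pi}[Q], shape (2, 3, 2)
actions = th.randint(0, 4, (2, 3, 2, 1))         # taken actions
q_taken = th.gather(q_vals, dim=3, index=actions).squeeze(3)  # shape (2, 3, 2)
advantage = q_taken - baseline                   # counterfactual-style advantage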