Example #1
# Imports assumed for this example (not part of the original snippet): AverageTracker is a
# project-local metric helper, and SummaryWriter may come from tensorboardX or torch.utils.tensorboard.
import shutil

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import RMSprop
from tensorboardX import SummaryWriter
from tqdm import tqdm


class Train:
    def __init__(self, model, trainloader, valloader, args):
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader
        self.args = args
        self.start_epoch = 0
        self.best_top1 = 0.0

        # Loss function and Optimizer
        self.loss = None
        self.optimizer = None
        self.create_optimization()

        # Model Loading
        self.load_pretrained_model()
        self.load_checkpoint(self.args.resume_from)

        # Tensorboard Writer
        self.summary_writer = SummaryWriter(log_dir=args.summary_dir)

    def train(self):
        for cur_epoch in range(self.start_epoch, self.args.num_epochs):

            # Initialize tqdm
            tqdm_batch = tqdm(self.trainloader,
                              desc="Epoch-" + str(cur_epoch) + "-")

            # Learning rate adjustment
            self.adjust_learning_rate(self.optimizer, cur_epoch)

            # Meters for tracking the average values
            loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

            # Set the model to be in training mode (for dropout and batchnorm)
            self.model.train()

            for data, target in tqdm_batch:

                if self.args.cuda:
                    data, target = data.cuda(), target.cuda()
                data_var, target_var = Variable(data), Variable(target)

                # Forward pass
                output = self.model(data_var)
                cur_loss = self.loss(output, target_var)

                # Optimization step
                self.optimizer.zero_grad()
                cur_loss.backward()
                self.optimizer.step()

                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.data,
                                                           target,
                                                           topk=(1, 5))
                loss.update(cur_loss.data[0])
                top1.update(cur_acc1[0])
                top5.update(cur_acc5[0])

            # Summary Writing
            self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg,
                                           cur_epoch)
            self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg,
                                           cur_epoch)

            # Print in console
            tqdm_batch.close()
            print("Epoch-" + str(cur_epoch) + " | " + "loss: " +
                  str(loss.avg) + " - acc-top1: " + str(top1.avg)[:7] +
                  "- acc-top5: " + str(top5.avg)[:7])

            # Evaluate on Validation Set
            if cur_epoch % self.args.test_every == 0 and self.valloader:
                self.test(self.valloader, cur_epoch)

            # Checkpointing
            is_best = top1.avg > self.best_top1
            self.best_top1 = max(top1.avg, self.best_top1)
            self.save_checkpoint(
                {
                    'epoch': cur_epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'best_top1': self.best_top1,
                    'optimizer': self.optimizer.state_dict(),
                }, is_best)

    def test(self, testloader, cur_epoch=-1):
        loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

        # Set the model to be in testing mode (for dropout and batchnorm)
        self.model.eval()

        for data, target in testloader:
            if self.args.cuda:
                data, target = data.cuda(), target.cuda()
            data_var, target_var = Variable(data, volatile=True), Variable(
                target, volatile=True)

            # Forward pass
            output = self.model(data_var)
            cur_loss = self.loss(output, target_var)

            # Top-1 and Top-5 Accuracy Calculation
            cur_acc1, cur_acc5 = self.compute_accuracy(output.data,
                                                       target,
                                                       topk=(1, 5))
            loss.update(cur_loss.data[0])
            top1.update(cur_acc1[0])
            top5.update(cur_acc5[0])

        if cur_epoch != -1:
            # Summary Writing
            self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-1-acc", top1.avg,
                                           cur_epoch)
            self.summary_writer.add_scalar("test-top-5-acc", top5.avg,
                                           cur_epoch)

        print("Test Results" + " | " + "loss: " + str(loss.avg) +
              " - acc-top1: " + str(top1.avg)[:7] + "- acc-top5: " +
              str(top5.avg)[:7])

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
        torch.save(state, self.args.checkpoint_dir + filename)
        if is_best:
            shutil.copyfile(self.args.checkpoint_dir + filename,
                            self.args.checkpoint_dir + 'model_best.pth.tar')

    def compute_accuracy(self, output, target, topk=(1, )):
        """Computes the accuracy@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)

        _, idx = output.topk(maxk, 1, True, True)
        idx = idx.t()
        correct = idx.eq(target.view(1, -1).expand_as(idx))

        acc_arr = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            acc_arr.append(correct_k.mul_(1.0 / batch_size))
        return acc_arr

    def adjust_learning_rate(self, optimizer, epoch):
        """Sets the learning rate to the initial LR multiplied by 0.98 every epoch"""
        learning_rate = self.args.learning_rate * (
            self.args.learning_rate_decay**epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    def create_optimization(self):
        self.loss = nn.CrossEntropyLoss()

        if self.args.cuda:
            self.loss.cuda()

        self.optimizer = RMSprop(self.model.parameters(),
                                 self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)

    def load_pretrained_model(self):
        try:
            print("Loading ImageNet pretrained weights...")
            pretrained_dict = torch.load(self.args.pretrained_path)
            self.model.load_state_dict(pretrained_dict)
            print("ImageNet pretrained weights loaded successfully.\n")
        except Exception:
            print("No ImageNet pretrained weights exist. Skipping...\n")

    def load_checkpoint(self, filename):
        filename = self.args.checkpoint_dir + filename
        try:
            print("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)
            self.start_epoch = checkpoint['epoch']
            self.best_top1 = checkpoint['best_top1']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("Checkpoint loaded successfully from '{}' at (epoch {})\n".
                  format(self.args.checkpoint_dir, checkpoint['epoch']))
        except Exception:
            print("No checkpoint exists from '{}'. Skipping...\n".format(
                self.args.checkpoint_dir))
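
The compute_accuracy method above follows the standard top-k recipe: take the k highest-scoring classes per sample and count how often the true label appears among them. Below is a minimal standalone sketch of the same computation, assuming nothing beyond torch (the tensors are toy values for illustration, not the project's own helper):

import torch

def topk_accuracy(output, target, topk=(1,)):
    """Fraction of samples whose true label is within the top-k predictions."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, idx = output.topk(maxk, dim=1, largest=True, sorted=True)  # indices of the k largest logits
    idx = idx.t()                                                 # shape: (maxk, batch)
    correct = idx.eq(target.view(1, -1).expand_as(idx))           # boolean hit matrix
    return [correct[:k].reshape(-1).float().sum().item() / batch_size for k in topk]

logits = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
labels = torch.tensor([1, 2])
print(topk_accuracy(logits, labels, topk=(1, 2)))  # [0.5, 0.5]
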
class ModelAndInfo:
    """
    This class contains the model and optional associated information, as well as methods to create
    models and optimizers, move these to GPU and load state from checkpoints. Attributes are:
      config: the model configuration information
      model: the model created based on the config
      optimizer: the optimizer created based on the config and associated with the model
      checkpoint_path: the path to load a checkpoint from; can be None
      mean_teacher_model: the mean teacher model, if and as specified by the config
      is_model_adjusted: whether model adjustments (which cannot be done twice) have been applied to the model
      is_mean_teacher_model_adjusted: whether model adjustments (which cannot be done twice)
      have been applied to the mean teacher model
      checkpoint_epoch: the epoch recorded in the checkpoint, if the model was loaded from disk
      model_execution_mode: the mode this model will be run in
    """

    MODEL_STATE_DICT_KEY = 'state_dict'
    OPTIMIZER_STATE_DICT_KEY = 'opt_dict'
    MEAN_TEACHER_STATE_DICT_KEY = 'mean_teacher_state_dict'
    EPOCH_KEY = 'epoch'

    def __init__(self,
                 config: ModelConfigBase,
                 model_execution_mode: ModelExecutionMode,
                 checkpoint_path: Optional[Path] = None):
        """
        :param config: the model configuration information
        :param model_execution_mode: mode this model will be run in
        :param checkpoint_path: the path to load a checkpoint from; can be None
        """
        self.config = config
        self.checkpoint_path = checkpoint_path
        self.model_execution_mode = model_execution_mode

        self._model = None
        self._mean_teacher_model = None
        self._optimizer = None
        self.checkpoint_epoch = None
        self.is_model_adjusted = False
        self.is_mean_teacher_model_adjusted = False

    @property
    def model(self) -> DeviceAwareModule:
        if not self._model:
            raise ValueError("Model has not been created.")
        return self._model

    @property
    def optimizer(self) -> Optimizer:
        if not self._optimizer:
            raise ValueError("Optimizer has not been created.")
        return self._optimizer

    @property
    def mean_teacher_model(self) -> Optional[DeviceAwareModule]:
        if not self._mean_teacher_model and self.config.compute_mean_teacher_model:
            raise ValueError("Mean teacher model has not been created.")
        return self._mean_teacher_model

    @classmethod
    def _load_checkpoint(cls, model: DeviceAwareModule, checkpoint_path: Path,
                         key_in_state_dict: str, use_gpu: bool) -> int:
        """
        Loads a checkpoint of a model, which may be the model or the mean teacher model. Assumes the model
        has already been created and the checkpoint exists. This does not set checkpoint_epoch.
        This method should not be called externally. Use instead try_load_checkpoint_for_model
        or try_load_checkpoint_for_mean_teacher_model.
        :param model: the model to load the weights into
        :param checkpoint_path: the path of the checkpoint file
        :param key_in_state_dict: the key for the model weights in the checkpoint state dict
        :param use_gpu: whether the model is being loaded onto a GPU
        :return: the checkpoint epoch from the state dict
        """
        logging.info(f"Loading checkpoint {checkpoint_path}")
        # For model debugging, allow loading a GPU trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if use_gpu else 'cpu'
        checkpoint = torch.load(str(checkpoint_path),
                                map_location=map_location)

        if isinstance(model, torch.nn.DataParallel):
            model.module.load_state_dict(checkpoint[key_in_state_dict])
        else:
            model.load_state_dict(checkpoint[key_in_state_dict])
        return checkpoint[ModelAndInfo.EPOCH_KEY]

    @classmethod
    def _adjust_for_gpus(
            cls, model: DeviceAwareModule, config: ModelConfigBase,
            model_execution_mode: ModelExecutionMode) -> DeviceAwareModule:
        """
        Updates a torch model so that input mini-batches are parallelized across the batch dimension to utilise
        multiple gpus. If model parallel is set to True and execution is in test mode, then model is partitioned to
        perform full volume inference.
        This assumes the model has been created, that the optimizer has not yet been created, and that the model has
        not already been adjusted. This method should not be called externally. Use instead adjust_model_for_gpus
        or adjust_mean_teacher_model_for_gpus.
        :return: the adjusted model
        """
        if config.use_gpu:
            model = model.cuda()
            logging.info(
                "Adjusting the model to use mixed precision training.")
            # If model parallel is set to True, then partition the network across all available gpus.
            if config.use_model_parallel:
                devices = config.get_cuda_devices()
                assert devices is not None  # for mypy
                model.partition_model(devices=devices)  # type: ignore
        else:
            logging.info(
                "Making no adjustments to the model because no GPU was found.")

        # Update model related config attributes (After Model Parallel Activated)
        config.adjust_after_mixed_precision_and_parallel(model)

        # DataParallel enables running the model with multiple gpus by splitting samples across GPUs
        # If the model is used in training mode, data parallel is activated by default.
        # Similarly, if model parallel is not activated, data parallel is used as a backup option
        use_data_parallel = (model_execution_mode == ModelExecutionMode.TRAIN
                             ) or (not config.use_model_parallel)
        if config.use_gpu and use_data_parallel:
            logging.info("Adjusting the model to use DataParallel")
            # Move all layers to the default GPU before activating data parallel.
            # This needs to happen even though we moved the model to the GPU at the beginning of the method,
            # because model parallelism may have spread it across multiple GPUs since then.
            model = model.cuda()
            model = DataParallelModel(model,
                                      device_ids=config.get_cuda_devices())

        return model

    def create_model(self) -> None:
        """
        Creates a model (with temperature scaling) according to the config given.
        """
        self._model = create_model_with_temperature_scaling(self.config)

    def try_load_checkpoint_for_model(self) -> bool:
        """
        Loads a checkpoint of a model. The provided model checkpoint must match the stored model.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be adjusted.")

        if not self.checkpoint_path:
            raise ValueError("No checkpoint provided")

        if not self.checkpoint_path.is_file():
            logging.warning(
                f'No checkpoint found at {self.checkpoint_path} current working dir {os.getcwd()}'
            )
            return False

        epoch = ModelAndInfo._load_checkpoint(
            model=self._model,
            checkpoint_path=self.checkpoint_path,
            key_in_state_dict=ModelAndInfo.MODEL_STATE_DICT_KEY,
            use_gpu=self.config.use_gpu)

        logging.info(f"Loaded model from checkpoint (epoch: {epoch})")
        self.checkpoint_epoch = epoch
        return True

    def adjust_model_for_gpus(self) -> None:
        """
        Updates the torch model so that input mini-batches are parallelized across the batch dimension to utilise
        multiple gpus. If model parallel is set to True and execution is in test mode, then model is partitioned to
        perform full volume inference.
        """
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be adjusted.")

        # Adjusting twice causes an error.
        if self.is_model_adjusted:
            logging.debug("model_and_info.is_model_adjusted is already True")

        if self._optimizer:
            raise ValueError(
                "Create an optimizer only after creating and adjusting the model."
            )

        self._model = ModelAndInfo._adjust_for_gpus(
            model=self._model,
            config=self.config,
            model_execution_mode=self.model_execution_mode)

        self.is_model_adjusted = True
        logging.debug("model_and_info.is_model_adjusted set to True")

    def create_summary_and_adjust_model_for_gpus(self) -> None:
        """
        Generates the model summary, which is required for model partitioning across GPUs, and then moves the model to
        GPU with data parallel/model parallel by calling adjust_model_for_gpus.
        """
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be adjusted.")

        if self.config.is_segmentation_model:
            summary_for_segmentation_models(self.config, self._model)
        # Prepare for mixed precision training and data parallelization (no-op if already done).
        # This relies on the information generated in the model summary.
        self.adjust_model_for_gpus()

    def try_create_model_and_load_from_checkpoint(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        Also updates the checkpoint_epoch.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        self.create_model()
        if self.checkpoint_path:
            # Load the stored model. If there is no checkpoint present, return immediately.
            return self.try_load_checkpoint_for_model()
        return True

    def try_create_model_load_from_checkpoint_and_adjust(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        The model is then adjusted for data parallelism and mixed precision.
        Also updates the checkpoint_epoch.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        success = self.try_create_model_and_load_from_checkpoint()
        self.create_summary_and_adjust_model_for_gpus()
        return success

    def create_mean_teacher_model(self) -> None:
        """
        Creates the mean teacher model (with temperature scaling) according to the config given.
        """
        self._mean_teacher_model = create_model_with_temperature_scaling(
            self.config)

    def try_load_checkpoint_for_mean_teacher_model(self) -> bool:
        """
        Loads a checkpoint of the mean teacher model. The provided checkpoint must match the stored model.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        if self._mean_teacher_model is None:
            raise ValueError(
                "Mean teacher model must be created before it can be adjusted."
            )

        if not self.checkpoint_path:
            raise ValueError("No checkpoint provided")

        if not self.checkpoint_path.is_file():
            logging.warning(
                f'No checkpoint found at {self.checkpoint_path} current working dir {os.getcwd()}'
            )
            return False

        epoch = ModelAndInfo._load_checkpoint(
            model=self._mean_teacher_model,
            checkpoint_path=self.checkpoint_path,
            key_in_state_dict=ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY,
            use_gpu=self.config.use_gpu)

        logging.info(
            f"Loaded mean teacher model from checkpoint (epoch: {epoch})")
        self.checkpoint_epoch = epoch
        return True

    def adjust_mean_teacher_model_for_gpus(self) -> None:
        """
        Updates the torch model so that input mini-batches are parallelized across the batch dimension to utilise
        multiple gpus. If model parallel is set to True and execution is in test mode, then model is partitioned to
        perform full volume inference.
        """
        if self._mean_teacher_model is None:
            raise ValueError(
                "Mean teacher model must be created before it can be adjusted."
            )

        # Adjusting twice causes an error.
        if self.is_mean_teacher_model_adjusted:
            logging.debug(
                "model_and_info.is_mean_teacher_model_adjusted is already True"
            )

        self._mean_teacher_model = ModelAndInfo._adjust_for_gpus(
            model=self._mean_teacher_model,
            config=self.config,
            model_execution_mode=self.model_execution_mode)

        self.is_mean_teacher_model_adjusted = True
        logging.debug(
            "model_and_info.is_mean_teacher_model_adjusted set to True")

    def create_summary_and_adjust_mean_teacher_model_for_gpus(self) -> None:
        """
        Generates the model summary, which is required for model partitioning across GPUs, and then moves the model to
        GPU with data parallel/model parallel by calling adjust_model_for_gpus.
        """
        if self._mean_teacher_model is None:
            raise ValueError(
                "Mean teacher model must be created before it can be adjusted."
            )

        if self.config.is_segmentation_model:
            summary_for_segmentation_models(self.config,
                                            self._mean_teacher_model)
        # Prepare for mixed precision training and data parallelization (no-op if already done).
        # This relies on the information generated in the model summary.
        self.adjust_mean_teacher_model_for_gpus()

    def try_create_mean_teacher_model_and_load_from_checkpoint(self) -> bool:
        """
        Creates the mean teacher model as per the config, and loads its parameters from the given checkpoint path.
        Also updates the checkpoint_epoch.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        self.create_mean_teacher_model()
        if self.checkpoint_path:
            # Load the stored model. If there is no checkpoint present, return immediately.
            return self.try_load_checkpoint_for_mean_teacher_model()
        return True

    def try_create_mean_teacher_model_load_from_checkpoint_and_adjust(
            self) -> bool:
        """
        Creates the mean teacher model as per the config, and loads its parameters from the given checkpoint path.
        The model is then adjusted for data parallelism and mixed precision.
        Also updates the checkpoint_epoch.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        success = self.try_create_mean_teacher_model_and_load_from_checkpoint()
        self.create_summary_and_adjust_mean_teacher_model_for_gpus()
        return success

    def create_optimizer(self) -> None:
        """
        Creates a torch optimizer for the given model, and stores it as an instance variable in the current object.
        """
        # Make sure model is created before we create optimizer
        if self._model is None:
            raise ValueError(
                "Model must be created before the optimizer can be created."
            )

        # Select optimizer type
        if self.config.optimizer_type in [
                OptimizerType.Adam, OptimizerType.AMSGrad
        ]:
            self._optimizer = torch.optim.Adam(
                self._model.parameters(),
                self.config.l_rate,
                self.config.adam_betas,
                self.config.opt_eps,
                self.config.weight_decay,
                amsgrad=self.config.optimizer_type == OptimizerType.AMSGrad)
        elif self.config.optimizer_type == OptimizerType.SGD:
            self._optimizer = torch.optim.SGD(
                self._model.parameters(),
                self.config.l_rate,
                self.config.momentum,
                weight_decay=self.config.weight_decay)
        elif self.config.optimizer_type == OptimizerType.RMSprop:
            self._optimizer = RMSprop(self._model.parameters(),
                                      self.config.l_rate,
                                      self.config.rms_alpha,
                                      self.config.opt_eps,
                                      self.config.weight_decay,
                                      self.config.momentum)
        else:
            raise NotImplementedError(
                f"Optimizer type {self.config.optimizer_type.value} is not implemented"
            )

    def try_load_checkpoint_for_optimizer(self) -> bool:
        """
        Loads a checkpoint of an optimizer.
        :return True if the checkpoint exists and optimizer state loaded, False otherwise
        """

        if self._optimizer is None:
            raise ValueError(
                "Optimizer must be created before optimizer checkpoint can be loaded."
            )

        if not self.checkpoint_path:
            logging.warning("No checkpoint path provided.")
            return False

        if not self.checkpoint_path.is_file():
            logging.warning(
                f'No checkpoint found at {self.checkpoint_path} current working dir {os.getcwd()}'
            )
            return False

        logging.info(f"Loading checkpoint {self.checkpoint_path}")
        # For model debugging, allow loading a GPU trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if self.config.use_gpu else 'cpu'
        checkpoint = torch.load(str(self.checkpoint_path),
                                map_location=map_location)

        if self._optimizer:
            self._optimizer.load_state_dict(
                checkpoint[ModelAndInfo.OPTIMIZER_STATE_DICT_KEY])

        logging.info(
            f"Loaded optimizer from checkpoint (epoch: {checkpoint[ModelAndInfo.EPOCH_KEY]})"
        )
        self.checkpoint_epoch = checkpoint[ModelAndInfo.EPOCH_KEY]
        return True

    def try_create_optimizer_and_load_from_checkpoint(self) -> bool:
        """
        Creates an optimizer and loads its state from a checkpoint.
        :return True if the checkpoint exists and optimizer state loaded, False otherwise
        """
        self.create_optimizer()
        if self.checkpoint_path:
            return self.try_load_checkpoint_for_optimizer()
        return True

    def save_checkpoint(self, epoch: int) -> None:
        """
        Saves a checkpoint of the current model and optimizer parameters in the specified folder
        and uploads it to the output blob storage of the current run context.
        The checkpoint's name for epoch 123 would be 123_checkpoint.pth.tar.
        :param epoch: The last epoch used to train the model.
        """
        logging.getLogger().disabled = True

        model_state_dict = self.model.module.state_dict() \
            if isinstance(self.model, torch.nn.DataParallel) else self.model.state_dict()
        checkpoint_file_path = self.config.get_path_to_checkpoint(epoch)
        info_to_store = {
            ModelAndInfo.EPOCH_KEY: epoch,
            ModelAndInfo.MODEL_STATE_DICT_KEY: model_state_dict,
            ModelAndInfo.OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict()
        }
        if self.config.compute_mean_teacher_model:
            assert self.mean_teacher_model is not None  # for mypy, getter has this built in
            mean_teacher_model_state_dict = self.mean_teacher_model.module.state_dict() \
                if isinstance(self.mean_teacher_model, torch.nn.DataParallel) \
                else self.mean_teacher_model.state_dict()
            info_to_store[
                ModelAndInfo.
                MEAN_TEACHER_STATE_DICT_KEY] = mean_teacher_model_state_dict

        torch.save(info_to_store, checkpoint_file_path)
        logging.getLogger().disabled = False
        logging.info(
            "Saved model checkpoint for epoch {epoch} to {checkpoint_file_path}"
        )
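
ModelAndInfo.save_checkpoint and _load_checkpoint above exchange checkpoints as plain dictionaries keyed by 'epoch', 'state_dict' and 'opt_dict'. The following is a self-contained sketch of that round trip with a toy model and a temporary path (both purely illustrative), assuming only torch:

import os
import tempfile

import torch
import torch.nn as nn

# Toy model and optimizer standing in for the ones ModelAndInfo manages.
model = nn.Linear(4, 2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)

# Save using the same dictionary layout as save_checkpoint above.
path = os.path.join(tempfile.mkdtemp(), "3_checkpoint.pth.tar")
torch.save({"epoch": 3,
            "state_dict": model.state_dict(),
            "opt_dict": optimizer.state_dict()}, path)

# Load it back, mirroring _load_checkpoint / try_load_checkpoint_for_optimizer.
checkpoint = torch.load(path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["opt_dict"])
print("resumed from epoch", checkpoint["epoch"])
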
def main(args):
    # Select the hardware device to use for inference.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    if args.arch == 'hg1':
        model = hg1(pretrained=False)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False)
    else:
        raise Exception('unrecognised model architecture: ' + args.arch)

    model = DataParallel(model).to(device)

    optimizer = RMSprop(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

    best_acc = 0

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # create data loader
    train_dataset = Mpii(args.image_path, is_train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    val_dataset = Mpii(args.image_path, is_train=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    # train and eval
    lr = args.lr
    for epoch in trange(args.start_epoch,
                        args.epochs,
                        desc='Overall',
                        ascii=True):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)

        # train for one epoch
        train_loss, train_acc = do_training_epoch(train_loader,
                                                  model,
                                                  device,
                                                  Mpii.DATA_INFO,
                                                  optimizer,
                                                  acc_joints=Mpii.ACC_JOINTS)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = do_validation_epoch(
            val_loader,
            model,
            device,
            Mpii.DATA_INFO,
            False,
            acc_joints=Mpii.ACC_JOINTS)

        # print metrics
        tqdm.write(
            f'[{epoch + 1:3d}/{args.epochs:3d}] lr={lr:0.2e} '
            f'train_loss={train_loss:0.4f} train_acc={100 * train_acc:0.2f} '
            f'valid_loss={valid_loss:0.4f} valid_acc={100 * valid_acc:0.2f}')

        # append logger file
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
        logger.plot_to_file(os.path.join(args.checkpoint, 'log.svg'),
                            ['Train Acc', 'Val Acc'])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
def train_sim(epoch_num=10,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              l2_penalty=0.0,
              momentum=0.0,
              log=False,
              loss_name='WGAN',
              model_name='dc',
              model_config=None,
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='CIFAR10',
              device='cpu',
              gpu_num=1):
    lr_d = 1e-4
    lr_g = 1e-4
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)

    optim_d = RMSprop(D.parameters(), lr=lr_d)
    optim_g = RMSprop(G.parameters(), lr=lr_g)

    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optim_d.load_state_dict(chk['d_optim'])
        optim_g.load_state_dict(chk['g_optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            D.zero_grad()
            G.zero_grad()
            loss.backward()
            optim_d.step()
            optim_g.step()

            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(
                    path=logdir,
                    name='%s-%s%.3f_%d.pth' %
                    (optim_type, model_name, lr_g, count + start_n),
                    D=D,
                    G=G,
                    optimizer=optim_d,
                    g_optimizer=optim_g)
            if wandb and log:
                wandb.log({
                    'Real score': d_real.mean().item(),
                    'Fake score': d_fake.mean().item(),
                    'Loss': loss.item()
                })
            count += 1
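
train_sim above delegates the loss to a project-specific get_loss helper. For the loss_name='WGAN' case, the usual convention is a critic loss of d_fake.mean() - d_real.mean(), optionally plus an L2 penalty on the critic's weights (matching the l2_penalty argument). The sketch below assumes that convention and is not the project's own implementation:

import torch
import torch.nn as nn

def wgan_d_loss(d_real, d_fake, critic=None, l2_weight=0.0):
    # Critic tries to push real scores up and fake scores down.
    loss = d_fake.mean() - d_real.mean()
    if critic is not None and l2_weight > 0.0:
        # Optional L2 penalty on the critic's parameters.
        loss = loss + l2_weight * sum(p.pow(2).sum() for p in critic.parameters())
    return loss

critic = nn.Linear(8, 1)
d_real = critic(torch.randn(16, 8))
d_fake = critic(torch.randn(16, 8))
print(wgan_d_loss(d_real, d_fake, critic, l2_weight=1e-4).item())
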
Example #5
class DQNAgent(TrainingAgent):
    def __init__(self, input_shape, action_space, seed, device, model, gamma,
                 alpha, tau, batch_size, update, replay, buffer_size, env,
                 decay=200, path='model', num_epochs=0, max_step=50000, learn_interval=20):

        '''Initialise a DQNAgent object.
        buffer_size   : size of the replay buffer to sample from
        gamma         : discount rate
        alpha         : learning rate
        replay        : number of stored transitions after which learning from the replay buffer starts
        update        : soft-update the target network after every `update` learning steps
        learn_interval: number of environment steps between learning steps
        '''
        super(DQNAgent, self).__init__(input_shape, action_space, seed, device, model,
                                       gamma, alpha, tau, batch_size, max_step, env, num_epochs, path)
        self.buffer_size = buffer_size
        self.update = update
        self.replay = replay
        self.interval = learn_interval
        # Q-Network
        self.policy_net = self.model(input_shape, action_space).to(self.device)
        self.target_net = self.model(input_shape, action_space).to(self.device)
        self.optimiser = RMSprop(self.policy_net.parameters(), lr=self.alpha)
        # Replay Memory
        self.memory = ReplayMemory(self.buffer_size, self.batch_size, self.seed, self.device)
        # Timestep
        self.t_step = 0
        self.l_step = 0

        self.EPSILON_START = 1.0
        self.EPSILON_FINAL = 0.02
        self.EPS_DECAY = decay
        self.epsilon_delta = lambda frame_idx: self.EPSILON_FINAL + (self.EPSILON_START - self.EPSILON_FINAL) * exp(-1. * frame_idx / self.EPS_DECAY)

    def step(self, state, action, reward, next_state, done):
        '''
        Step of learning and taking environment action.
        '''

        # Save experience into replay buffer
        self.memory.add(state, action, reward, next_state, done)

        # Learn every `learn_interval` environment steps
        self.t_step = (self.t_step + 1) % self.interval

        if self.t_step == 0:
            # if there are enough samples in the memory, get a random subset and learn
            if len(self.memory) > self.replay:
                experience = self.memory.sample()
                print('learning')
                self.learn(experience)


    def action(self, state, eps=0.):
        ''' Returns action for given state as per current policy'''
        # Convert the state to a batched tensor on the agent's device
        state = torch.from_numpy(state).unsqueeze(0).to(self.device)
        if rand.rand() > eps:
            # Eps Greedy action selections
            action_val = self.policy_net(state)
            return np.argmax(action_val.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_space))

    def learn(self, exp):
        state, action, reward, next_state, done = exp

        # Get expected Q values from Policy Model
        Q_expt_current = self.policy_net(state)
        Q_expt = Q_expt_current.gather(1, action.unsqueeze(1)).squeeze(1)

        # Get max predicted Q values for next state from target model
        Q_target_next = self.target_net(next_state).detach().max(1)[0]
        # Compute Q targets for current states
        Q_target = reward + (self.gamma * Q_target_next * (1 - done))

        # Compute Loss
        loss = torch.nn.functional.mse_loss(Q_expt, Q_target)

        # Minimize loss
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        self.l_step = (self.l_step + 1) % self.update
        if self.l_step == 0:
            self.soft_update(self.policy_net, self.target_net, self.tau)

    def model_dict(self) -> dict:
        ''' To save models'''
        return {'policy_net': self.policy_net.state_dict(), 'target_net': self.target_net.state_dict(),
                'optimizer': self.optimiser.state_dict(), 'num_epoch': self.num_epochs,'scores': self.scores}

    def load_model(self, state_dict, eval=True):
        '''Load Parameters and Model Information from prior training for continuation of training'''
        self.policy_net.load_state_dict(state_dict['policy_net'])
        self.target_net.load_state_dict(state_dict['target_net'])
        self.optimiser.load_state_dict(state_dict['optimizer'])
        self.scores = state_dict['scores']
        if eval:
            self.policy_net.eval()
            self.target_net.eval()
        else:
            self.policy_net.train()
            self.target_net.train()
        #Load the model
        self.num_epochs = state_dict['num_epoch']

    # θ'=θ×τ+θ'×(1−τ)
    def soft_update(self, policy_model, target_model, tau):
        for t_param, p_param in zip(target_model.parameters(), policy_model.parameters()):
            t_param.data.copy_(tau * p_param.data + (1.0 - tau) * t_param.data)

    def train(self, n_episodes=1000, render=False):
        """
        n_episodes: maximum number of training episodes
        Saves the model every 100 episodes
        """
        filename = get_filename()

        # Toggle rendering on or off
        self.env.render(render)
        for i_episode in range(n_episodes):
            self.num_epochs += 1
            state = self.stack_frames(None, self.reset(), True)
            score = 0
            eps = self.epsilon_delta(self.num_epochs)

            while True:
                action = self.action(state, eps)

                next_state, reward, done, info = self.env.step(action)

                score += reward

                next_state = self.stack_frames(state, next_state, False)

                self.step(state, action, reward, next_state, done)
                state = next_state
                if done:
                    break
            self.scores.append(score)  # save most recent score

            # Every 100 episodes, save the model and plot the scores
            if i_episode % 100 == 0:
                self.save_obj(self.model_dict(), os.path.join(self.path, filename))
                print(f"Creating plot")
                # Plot a figure
                fig = plt.figure()

                # Add a subplot
                # ax = fig.add_subplot(111)

                # Plot the graph
                plt.plot(np.arange(len(self.scores)), self.scores)

                # Add labels
                plt.xlabel('Episode #')
                plt.ylabel('Score')

                # Save the plot
                plt.savefig(f'{i_episode} plot.png')
                print(f"Plot saved")

        # Return the scores.
        return self.scores
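
DQNAgent.soft_update above implements θ' = τ·θ + (1 − τ)·θ', and epsilon_delta decays the exploration rate exponentially from EPSILON_START towards EPSILON_FINAL. Here is a small standalone check of both, using toy values only:

from math import exp

import torch
import torch.nn as nn

def soft_update(policy_model, target_model, tau):
    # θ' = τ·θ + (1 − τ)·θ', applied parameter by parameter.
    for t_param, p_param in zip(target_model.parameters(), policy_model.parameters()):
        t_param.data.copy_(tau * p_param.data + (1.0 - tau) * t_param.data)

policy, target = nn.Linear(2, 2), nn.Linear(2, 2)
soft_update(policy, target, tau=0.1)  # target moves 10% of the way toward the policy weights

eps_start, eps_final, decay = 1.0, 0.02, 200
epsilon = lambda frame: eps_final + (eps_start - eps_final) * exp(-frame / decay)
print(epsilon(0), epsilon(200), epsilon(1000))  # ≈ 1.0, 0.38, 0.027
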
class ModelAndInfo:
    """
    This class contains the model and optional associated information, as well as methods to create
    models and optimizers, move these to GPU and load state from checkpoints. Attributes are:
      config: the model configuration information
      model: the model created based on the config
      optimizer: the optimizer created based on the config and associated with the model
      checkpoint_path: the path to load a checkpoint from; can be None
      is_mean_teacher: whether this is (intended to be) a mean teacher model
      is_adjusted: whether model adjustments (which cannot be done twice) have been applied
      checkpoint_epoch: the training epoch this model was created, if loaded from disk
      model_execution_mode: mode this model will be run in
    """
    def __init__(self,
                 config: ModelConfigBase,
                 model_execution_mode: ModelExecutionMode,
                 is_mean_teacher: bool = False,
                 checkpoint_path: Optional[Path] = None):
        """
        :param config: the model configuration information
        :param model_execution_mode: mode this model will be run in
        :param is_mean_teacher: whether this is (intended to be) a mean teacher model
        :param checkpoint_path: the path to load a checkpoint from; can be None
        """
        self.config = config
        self.is_mean_teacher = is_mean_teacher
        self.checkpoint_path = checkpoint_path
        self.model_execution_mode = model_execution_mode

        self._model = None
        self._optimizer = None
        self.checkpoint_epoch = None
        self.is_adjusted = False

    @property
    def model(self) -> DeviceAwareModule:
        if not self._model:
            raise ValueError("Model has not been created.")
        return self._model

    @property
    def optimizer(self) -> Optimizer:
        if not self._optimizer:
            raise ValueError("Optimizer has not been created.")
        return self._optimizer

    def to_cuda(self) -> None:
        """
        Moves the model to GPU
        """
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be moved to GPU.")
        self._model = self._model.cuda()

    def set_data_parallel(self, device_ids: Optional[List[Any]]) -> None:
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be moved to Data Parellel."
            )
        self._model = DataParallelModel(self._model, device_ids=device_ids)

    def create_model(self) -> None:
        """
        Creates a model (with temperature scaling) according to the config given.
        """
        self._model = create_model_with_temperature_scaling(self.config)

    def try_load_checkpoint_for_model(self) -> bool:
        """
        Loads a checkpoint of a model. The provided model checkpoint must match the stored model.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be adjusted.")

        if not self.checkpoint_path:
            raise ValueError("No checkpoint provided")

        if not self.checkpoint_path.is_file():
            logging.warning(
                f'No checkpoint found at {self.checkpoint_path} current working dir {os.getcwd()}'
            )
            return False

        logging.info(f"Loading checkpoint {self.checkpoint_path}")
        # For model debugging, allow loading a GPU trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if self.config.use_gpu else 'cpu'
        checkpoint = torch.load(str(self.checkpoint_path),
                                map_location=map_location)

        if isinstance(self._model, torch.nn.DataParallel):
            self._model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self._model.load_state_dict(checkpoint['state_dict'])

        logging.info(
            f"Loaded model from checkpoint (epoch: {checkpoint['epoch']})")
        self.checkpoint_epoch = checkpoint['epoch']
        return True

    def adjust_model_for_gpus(self) -> None:
        """
        Updates the torch model so that input mini-batches are parallelized across the batch dimension to utilise
        multiple gpus. If model parallel is set to True and execution is in test mode, then model is partitioned to
        perform full volume inference.
        """
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be adjusted.")

        # Adjusting twice causes an error.
        if self.is_adjusted:
            logging.debug("model_and_info.is_adjusted is already True")

        if self._optimizer:
            raise ValueError(
                "Create an optimizer only after creating and adjusting the model."
            )

        if self.config.use_gpu:
            self.to_cuda()
            logging.info(
                "Adjusting the model to use mixed precision training.")
            # If model parallel is set to True, then partition the network across all available gpus.
            if self.config.use_model_parallel:
                devices = self.config.get_cuda_devices()
                assert devices is not None  # for mypy
                self._model.partition_model(devices=devices)  # type: ignore
        else:
            logging.info(
                "Making no adjustments to the model because no GPU was found.")

        # Update model related config attributes (After Model Parallel Activated)
        self.config.adjust_after_mixed_precision_and_parallel(self._model)

        # DataParallel enables running the model with multiple gpus by splitting samples across GPUs
        # If the model is used in training mode, data parallel is activated by default.
        # Similarly, if model parallel is not activated, data parallel is used as a backup option
        use_data_parallel = (self.model_execution_mode
                             == ModelExecutionMode.TRAIN) or (
                                 not self.config.use_model_parallel)
        if self.config.use_gpu and use_data_parallel:
            logging.info("Adjusting the model to use DataParallel")
            # Move all layers to the default GPU before activating data parallel.
            # This needs to happen even though we moved the model to the GPU at the beginning of the method,
            # because model parallelism may have spread it across multiple GPUs since then.
            self.to_cuda()
            self.set_data_parallel(device_ids=self.config.get_cuda_devices())

        self.is_adjusted = True
        logging.debug("model_and_info.is_adjusted set to True")

    def create_summary_and_adjust_model_for_gpus(self) -> None:
        """
        Generates the model summary, which is required for model partitioning across GPUs, and then moves the model to
        GPU with data parallel/model parallel by calling adjust_model_for_gpus.
        """
        if self._model is None:
            raise ValueError(
                "Model must be created before it can be adjusted.")

        if self.config.is_segmentation_model:
            summary_for_segmentation_models(self.config, self._model)
        # Prepare for mixed precision training and data parallelization (no-op if already done).
        # This relies on the information generated in the model summary.
        self.adjust_model_for_gpus()

    def try_create_model_and_load_from_checkpoint(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        Also updates the checkpoint_epoch.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        self.create_model()

        # for mypy
        assert self._model

        if self.checkpoint_path:
            # Load the stored model. If there is no checkpoint present, return immediately.
            return self.try_load_checkpoint_for_model()
        return True

    def try_create_model_load_from_checkpoint_and_adjust(self) -> bool:
        """
        Creates a model as per the config, and loads the parameters from the given checkpoint path.
        The model is then adjusted for data parallelism and mixed precision, running in TEST mode.
        Also updates the checkpoint_epoch.
        :return True if checkpoint exists and was loaded, False otherwise.
        """
        success = self.try_create_model_and_load_from_checkpoint()
        self.create_summary_and_adjust_model_for_gpus()
        return success

    def create_optimizer(self) -> None:
        """
        Creates a torch optimizer for the given model, and stores it as an instance variable in the current object.
        """
        # Make sure model is created before we create optimizer
        if self._model is None:
            raise ValueError(
                "Model must be created before the optimizer can be created."
            )

        # Select optimizer type
        if self.config.optimizer_type in [
                OptimizerType.Adam, OptimizerType.AMSGrad
        ]:
            self._optimizer = torch.optim.Adam(
                self._model.parameters(),
                self.config.l_rate,
                self.config.adam_betas,
                self.config.opt_eps,
                self.config.weight_decay,
                amsgrad=self.config.optimizer_type == OptimizerType.AMSGrad)
        elif self.config.optimizer_type == OptimizerType.SGD:
            self._optimizer = torch.optim.SGD(
                self._model.parameters(),
                self.config.l_rate,
                self.config.momentum,
                weight_decay=self.config.weight_decay)
        elif self.config.optimizer_type == OptimizerType.RMSprop:
            self._optimizer = RMSprop(self._model.parameters(),
                                      self.config.l_rate,
                                      self.config.rms_alpha,
                                      self.config.opt_eps,
                                      self.config.weight_decay,
                                      self.config.momentum)
        else:
            raise NotImplementedError(
                f"Optimizer type {self.config.optimizer_type.value} is not implemented"
            )

    def try_load_checkpoint_for_optimizer(self) -> bool:
        """
        Loads a checkpoint of an optimizer.
        :return True if the checkpoint exists and optimizer state loaded, False otherwise
        """

        if self._optimizer is None:
            raise ValueError(
                "Optimizer must be created before optimizer checkpoint can be loaded."
            )

        if not self.checkpoint_path:
            logging.warning("No checkpoint path provided.")
            return False

        if not self.checkpoint_path.is_file():
            logging.warning(
                f'No checkpoint found at {self.checkpoint_path} current working dir {os.getcwd()}'
            )
            return False

        logging.info(f"Loading checkpoint {self.checkpoint_path}")
        # For model debugging, allow loading a GPU trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if self.config.use_gpu else 'cpu'
        checkpoint = torch.load(str(self.checkpoint_path),
                                map_location=map_location)

        if self._optimizer:
            self._optimizer.load_state_dict(checkpoint['opt_dict'])

        logging.info(
            "Loaded optimizer from checkpoint (epoch: {checkpoint['epoch']})")
        self.checkpoint_epoch = checkpoint['epoch']
        return True

    def try_create_optimizer_and_load_from_checkpoint(self) -> bool:
        """
        Creates an optimizer and loads its state from a checkpoint.
        :return True if the checkpoint exists and optimizer state loaded, False otherwise
        """
        self.create_optimizer()
        if self.checkpoint_path:
            return self.try_load_checkpoint_for_optimizer()
        return True
Example #7
            loss = lossCri(predictTensor, batchLabels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            if trainingTimes % display_step == 0:
                print("#################")
                print("Predict tensor is ", predictTensor)
                print("Labels are ", batchLabels)
                print("Learning rate is ",
                      optimizer.state_dict()['param_groups'][0]["lr"])
                print("Loss is ", loss)
                print("Training time is ", trainingTimes)
            learning_rate = scheduler.calculateLearningRate()
            state_dic = optimizer.state_dict()
            state_dic["param_groups"][0]["lr"] = float(learning_rate)
            optimizer.load_state_dict(state_dic)
            trainingTimes += 1
            if trainingTimes % save_model_steps == 0:
                torch.save(
                    model.state_dict(),
                    weight_save_path + "ALBERT_" + str(trainingTimes) + ".pth")
else:
    model.eval()
    model.load_state_dict(
        torch.load(weight_save_path + "ALBERT_" + str(testModelSelect) +
                   ".pth"))
    predictLabels = []
    truthLabels = []
    print("POSITIVE SAMPLES PREDICT.")
    k = 0
    for sample in test_positive_samples:
Example #8
class OffPGLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mac = mac
        self.logger = logger

        self.last_target_update_step = 0
        self.critic_training_steps = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = OffPGCritic(scheme, args)
        self.mixer = QMixer(args)
        self.target_critic = copy.deepcopy(self.critic)
        self.target_mixer = copy.deepcopy(self.mixer)

        self.agent_params = list(mac.parameters())
        self.critic_params = list(self.critic.parameters())
        self.mixer_params = list(self.mixer.parameters())
        self.params = self.agent_params + self.critic_params
        self.c_params = self.critic_params + self.mixer_params

        self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr)
        self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.lr)
        self.mixer_optimiser = RMSprop(params=self.mixer_params, lr=args.lr)

        print('Critic + mixer parameter count: ')
        print(get_parameters_num(list(self.c_params)))

    def train(self, batch: EpisodeBatch, t_env: int, log):
        # Get the relevant quantities
        bs = batch.batch_size
        max_t = batch.max_seq_length
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        avail_actions = batch["avail_actions"][:, :-1]
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
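        # (Illustrative note, not from the original code: "filled" marks real timesteps
        #  with 1; multiplying mask[:, 1:] by (1 - terminated[:, :-1]) zeroes the step
        #  immediately after an episode terminates, while later padded steps are already
        #  0 in "filled", so the masked sums below only count valid transitions.)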
        mask = mask.repeat(1, 1, self.n_agents).view(-1)
        states = batch["state"][:, :-1]

        #build q
        inputs = self.critic._build_inputs(batch, bs, max_t)
        q_vals = self.critic.forward(inputs).detach()[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        # Calculate the counterfactual baseline: b(s) = sum_a pi(a|s) * Q(s, a) per agent
        q_taken = th.gather(q_vals, dim=3, index=actions).squeeze(3)
        pi = mac_out.view(-1, self.n_actions)
        baseline = th.sum(mac_out * q_vals, dim=-1).view(-1).detach()

        # Calculate policy grad with mask
        pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = th.log(pi_taken)
        coe = self.mixer.k(states).view(-1)

        advantages = (q_taken.view(-1) - baseline)
        # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy-gradient loss: -E[ k(s) * A(s, u) * log pi(u|s) ], averaged over unmasked timesteps
        coma_loss = - ((coe * advantages.detach() * log_pi_taken) * mask).sum() / mask.sum()
        
        # dist_entropy = Categorical(pi).entropy().view(-1)
        # dist_entropy[mask == 0] = 0 # fill nan
        # entropy_loss = (dist_entropy * mask).sum() / mask.sum()
 
        # loss = coma_loss - self.args.ent_coef * entropy_loss / entropy_loss.item()
        loss = coma_loss

        # Optimise agents
        self.agent_optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip)
        self.agent_optimiser.step()

        # Compute the sum of parameter magnitudes (debugging only; not used further)
        p_sum = 0.
        for p in self.agent_params:
            p_sum += p.data.abs().sum().item() / 100.0


        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(log["critic_loss"])
            for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean", "q_max_mean", "q_min_mean", "q_max_var", "q_min_var"]:
                self.logger.log_stat(key, sum(log[key])/ts_logged, t_env)
            self.logger.log_stat("q_max_first", log["q_max_first"], t_env)
            self.logger.log_stat("q_min_first", log["q_min_first"], t_env)
            #self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env)
            # self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env)
            self.logger.log_stat("coma_loss", coma_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env)
            self.log_stats_t = t_env

    def train_critic(self, on_batch, best_batch=None, log=None):
        bs = on_batch.batch_size
        max_t = on_batch.max_seq_length
        rewards = on_batch["reward"][:, :-1]
        actions = on_batch["actions"][:, :]
        terminated = on_batch["terminated"][:, :-1].float()
        mask = on_batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = on_batch["avail_actions"][:]
        states = on_batch["state"]

        #build_target_q
        target_inputs = self.target_critic._build_inputs(on_batch, bs, max_t)
        target_q_vals = self.target_critic.forward(target_inputs).detach()
        targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)
        target_q = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda).detach()

        inputs = self.critic._build_inputs(on_batch, bs, max_t)


        if best_batch is not None:
            best_target_q, best_inputs, best_mask, best_actions, best_mac_out= self.train_critic_best(best_batch)
            log["best_reward"] = th.mean(best_batch["reward"][:, :-1].squeeze(2).sum(-1), dim=0)
            target_q = th.cat((target_q, best_target_q), dim=0)
            inputs = th.cat((inputs, best_inputs), dim=0)
            mask = th.cat((mask, best_mask), dim=0)
            actions = th.cat((actions, best_actions), dim=0)
            states = th.cat((states, best_batch["state"]), dim=0)

        #train critic
        for t in range(max_t - 1):
            mask_t = mask[:, t:t+1]
            if mask_t.sum() < 0.5:
                continue
            q_vals = self.critic.forward(inputs[:, t:t+1])
            q_ori = q_vals
            q_vals = th.gather(q_vals, 3, index=actions[:, t:t+1]).squeeze(3)
            q_vals = self.mixer.forward(q_vals, states[:, t:t+1])
            target_q_t = target_q[:, t:t+1].detach()
            q_err = (q_vals - target_q_t) * mask_t
            critic_loss = (q_err ** 2).sum() / mask_t.sum()

            self.critic_optimiser.zero_grad()
            self.mixer_optimiser.zero_grad()
            critic_loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(self.c_params, self.args.grad_norm_clip)
            self.critic_optimiser.step()
            self.mixer_optimiser.step()
            self.critic_training_steps += 1

            log["critic_loss"].append(critic_loss.item())
            log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            log["td_error_abs"].append((q_err.abs().sum().item() / mask_elems))
            log["target_mean"].append((target_q_t * mask_t).sum().item() / mask_elems)
            log["q_taken_mean"].append((q_vals * mask_t).sum().item() / mask_elems)
            log["q_max_mean"].append((th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_mean"].append((th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_max_var"].append((th.var(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_var"].append((th.var(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)

            if t == 0:
                log["q_max_first"] = (th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems
                log["q_min_first"] = (th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems

        #update target network
        if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps



    def train_critic_best(self, batch):
        bs = batch.batch_size
        max_t = batch.max_seq_length
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:]
        states = batch["state"]

        with th.no_grad():
            # pr for all actions of the episode
            mac_out = []
            self.mac.init_hidden(bs)
            for i in range(max_t):
                agent_outs = self.mac.forward(batch, t=i)
                mac_out.append(agent_outs)
            mac_out = th.stack(mac_out, dim=1).detach()
            # Mask out unavailable actions, renormalise (as in action selection)
            mac_out[avail_actions == 0] = 0
            mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
            mac_out[avail_actions == 0] = 0
            critic_mac = th.gather(mac_out, 3, actions).squeeze(3).prod(dim=2, keepdim=True)

            #target_q take
            target_inputs = self.target_critic._build_inputs(batch, bs, max_t)
            target_q_vals = self.target_critic.forward(target_inputs).detach()
            targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)

            #expected q
            exp_q = self.build_exp_q(target_q_vals, mac_out, states).detach()
            # td-error
            targets_taken[:, -1] = targets_taken[:, -1] * (1 - th.sum(terminated, dim=1))
            exp_q[:, -1] = exp_q[:, -1] * (1 - th.sum(terminated, dim=1))
            targets_taken[:, :-1] = targets_taken[:, :-1] * mask
            exp_q[:, :-1] = exp_q[:, :-1] * mask
            td_q = (rewards + self.args.gamma * exp_q[:, 1:] - targets_taken[:, :-1]) * mask

            #compute target
            target_q =  build_target_q(td_q, targets_taken[:, :-1], critic_mac, mask, self.args.gamma, self.args.tb_lambda, self.args.step).detach()

            inputs = self.critic._build_inputs(batch, bs, max_t)

        return target_q, inputs, mask, actions, mac_out


    def build_exp_q(self, target_q_vals, mac_out, states):
        target_exp_q_vals = th.sum(target_q_vals * mac_out, dim=3)
        target_exp_q_vals = self.target_mixer.forward(target_exp_q_vals, states)
        return target_exp_q_vals

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.mixer.cuda()
        self.target_critic.cuda()
        self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path))
        th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path))
        th.save(self.mixer_optimiser.state_dict(), "{}/mixer_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage))
        self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage))
        # Not quite right but I don't want to save target networks
        # self.target_critic.load_state_dict(self.critic.agent.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage))
        self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
        self.mixer_optimiser.load_state_dict(th.load("{}/mixer_opt.th".format(path), map_location=lambda storage, loc: storage))
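
The critic targets in train_critic come from build_td_lambda_targets, which is imported from elsewhere in the repository. As a reference point, a PyMARL-style lambda-return helper typically follows the backward recursion G_t = r_t + gamma * [(1 - lambda) * Q_target(t+1) + lambda * G_{t+1}], with masking for padded steps. The sketch below is a reconstruction under that assumption, not the exact helper used by this example.

import torch as th


def build_td_lambda_targets_sketch(rewards, terminated, mask, target_qs,
                                   n_agents, gamma, td_lambda):
    # rewards/terminated/mask: [batch, T-1, 1]; target_qs: [batch, T, 1] (mixed values).
    # n_agents is kept only to match the call signature above; it is unused here.
    ret = target_qs.new_zeros(*target_qs.shape)
    # Bootstrap from the final value only for episodes that never terminated.
    ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1))
    # Backward recursion of the lambda-return ("forward view" computed backwards).
    for t in range(ret.shape[1] - 2, -1, -1):
        ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] * (
            rewards[:, t]
            + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t]))
    # Drop the last step so the targets line up with rewards: [batch, T-1, 1].
    return ret[:, 0:-1]


# Tiny smoke test with dummy shapes (batch=2, T=5), illustrative values only.
T = 5
dummy_targets = build_td_lambda_targets_sketch(
    rewards=th.zeros(2, T - 1, 1), terminated=th.zeros(2, T - 1, 1),
    mask=th.ones(2, T - 1, 1), target_qs=th.ones(2, T, 1),
    n_agents=3, gamma=0.99, td_lambda=0.8)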