    def __init__(self, config):
        super().__init__(config)
        # Create an instance of the model
        self.logger.info("Loading encoder pretrained on ImageNet...")
        if self.config.pretrained_encoder:
            pretrained_enc = torch.nn.DataParallel(
                ERFNet(self.config.imagenet_nclasses)).cuda()
            pretrained_enc.load_state_dict(
                torch.load(self.config.pretrained_model_path)['state_dict'])
            pretrained_enc = next(pretrained_enc.children()).features.encoder
        else:
            pretrained_enc = None
        # define erfNet model
        self.model = ERF(self.config, pretrained_enc)
        # Create an instance of the data loader
        self.data_loader = VOCDataLoader(self.config)
        # Create an instance of the loss
        self.loss = CrossEntropyLoss(self.config)
        # Create an instance of the optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config.learning_rate,
            betas=(self.config.betas[0], self.config.betas[1]),
            eps=self.config.eps,
            weight_decay=self.config.weight_decay)
        # Define Scheduler
        lambda1 = lambda epoch: pow(
            (1 - ((epoch - 1) / self.config.max_epoch)), 0.9)
        self.scheduler = lr_scheduler.LambdaLR(self.optimizer,
                                               lr_lambda=lambda1)
        # initialize my counters
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_mean_iou = 0

        # Check if CUDA is available
        self.is_cuda = torch.cuda.is_available()
        # Use CUDA only if it is both available and enabled in the config
        self.cuda = self.is_cuda & self.config.cuda

        if self.cuda:
            torch.cuda.manual_seed_all(self.config.seed)
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            self.logger.info("Operation will be on *****GPU-CUDA***** ")
            print_cuda_statistics()

        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.config.seed)
            self.logger.info("Operation will be on *****CPU***** ")

        self.model = self.model.to(self.device)
        self.loss = self.loss.to(self.device)
        # Load the model from the latest checkpoint; if none is found, start from scratch.
        self.load_checkpoint(self.config.checkpoint_file)

        # Tensorboard Writer
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='FCN8s')
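The scheduler above implements a polynomial ("poly") learning-rate decay. Below is a minimal, self-contained sketch of how that LambdaLR schedule behaves; the stand-in linear model and the max_epoch value are assumptions for illustration only.

import torch
from torch.optim import lr_scheduler

model = torch.nn.Linear(4, 2)               # stand-in for the ERF model
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
max_epoch = 100                             # assumed value of config.max_epoch
lambda1 = lambda epoch: pow((1 - ((epoch - 1) / max_epoch)), 0.9)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

for epoch in range(3):
    optimizer.step()                        # one (dummy) epoch of training
    scheduler.step()                        # base lr is multiplied by lambda1(epoch)
    print(epoch, scheduler.get_last_lr())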
Example #2
    def __init__(self, config):
        super().__init__(config)
        # Create an instance of the model
        self.model = CondenseNet(self.config)
        # Create an instance of the data loader
        self.data_loader = Cifar10DataLoader(self.config)
        # Create an instance of the loss
        self.loss = CrossEntropyLoss()
        # Create an instance of the optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.learning_rate,
                                         momentum=float(self.config.momentum),
                                         weight_decay=self.config.weight_decay,
                                         nesterov=True)
        # initialize my counters
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_acc = 0
        # Check if CUDA is available
        self.is_cuda = torch.cuda.is_available()
        # Use CUDA only if it is both available and enabled in the config
        self.cuda = self.is_cuda & self.config.cuda

        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.manual_seed_all(self.config.seed)
            torch.cuda.set_device(self.config.gpu_device)
            self.logger.info("Operation will be on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.config.seed)
            self.logger.info("Operation will be on *****CPU***** ")

        self.model = self.model.to(self.device)
        self.loss = self.loss.to(self.device)
        # Load the model from the latest checkpoint; if none is found, start from scratch.
        self.load_checkpoint(self.config.checkpoint_file)
        # Tensorboard Writer
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='CondenseNet')
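Several of these agents call a print_cuda_statistics() helper imported from the project's utilities. The sketch below is a hedged guess at what such a helper might do; the repository's real implementation may differ.

import torch

def print_cuda_statistics():
    # Illustrative only: report basic CUDA availability and device information.
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        current = torch.cuda.current_device()
        print("device count:", torch.cuda.device_count())
        print("current device:", current)
        print("device name:", torch.cuda.get_device_name(current))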
Example #3
class CondenseNetAgent(BaseAgent):
    def __init__(self, config):
        super().__init__(config)
        # Create an instance of the model
        self.model = CondenseNet(self.config)
        # Create an instance of the data loader
        self.data_loader = Cifar10DataLoader(self.config)
        # Create an instance of the loss
        self.loss = CrossEntropyLoss()
        # Create an instance of the optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.config.learning_rate,
                                         momentum=float(self.config.momentum),
                                         weight_decay=self.config.weight_decay,
                                         nesterov=True)
        # initialize my counters
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_acc = 0
        # Check if CUDA is available
        self.is_cuda = torch.cuda.is_available()
        # Use CUDA only if it is both available and enabled in the config
        self.cuda = self.is_cuda & self.config.cuda

        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.manual_seed_all(self.config.seed)
            torch.cuda.set_device(self.config.gpu_device)
            self.logger.info("Operation will be on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.config.seed)
            self.logger.info("Operation will be on *****CPU***** ")

        self.model = self.model.to(self.device)
        self.loss = self.loss.to(self.device)
        # Load the model from the latest checkpoint; if none is found, start from scratch.
        self.load_checkpoint(self.config.checkpoint_file)
        # Tensorboard Writer
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='CondenseNet')

    def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
        """
        Saving the latest checkpoint of the training
        :param filename: filename which will contain the state
        :param is_best: flag indicating whether this is the best model so far
        :return:
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        # Save the state
        torch.save(state, self.config.checkpoint_dir + filename)
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(self.config.checkpoint_dir + filename,
                            self.config.checkpoint_dir + 'model_best.pth.tar')

    def load_checkpoint(self, filename):
        filename = self.config.checkpoint_dir + filename
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])

            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_dir, checkpoint['epoch'],
                        checkpoint['iteration']))
        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def run(self):
        """
        This function runs the operator
        :return:
        """
        try:
            if self.config.mode == 'test':
                self.validate()
            else:
                self.train()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training function, with per-epoch model saving
        """
        for epoch in range(self.current_epoch, self.config.max_epoch):
            self.current_epoch = epoch
            self.train_one_epoch()

            valid_acc = self.validate()
            is_best = valid_acc > self.best_valid_acc
            if is_best:
                self.best_valid_acc = valid_acc
            self.save_checkpoint(is_best=is_best)

    def train_one_epoch(self):
        """
        One epoch training function
        """
        # Initialize tqdm
        tqdm_batch = tqdm(self.data_loader.train_loader,
                          total=self.data_loader.train_iterations,
                          desc="Epoch-{}-".format(self.current_epoch))
        # Set the model to be in training mode
        self.model.train()
        # Initialize your average meters
        epoch_loss = AverageMeter()
        top1_acc = AverageMeter()
        top5_acc = AverageMeter()

        current_batch = 0
        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.cuda(self.config.async_loading), y.cuda(
                    self.config.async_loading)

            # current iteration over total iterations
            progress = float(
                self.current_epoch * self.data_loader.train_iterations +
                current_batch) / (self.config.max_epoch *
                                  self.data_loader.train_iterations)
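            # progress runs from 0 to 1 over the whole training run; it is passed to
            # CondenseNet below, which presumably uses it to schedule its condensation
            # (learned group-convolution pruning) stages.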
            # progress = float(self.current_iteration) / (self.config.max_epoch * self.data_loader.train_iterations)
            x, y = Variable(x), Variable(y)
            lr = adjust_learning_rate(self.optimizer,
                                      self.current_epoch,
                                      self.config,
                                      batch=current_batch,
                                      nBatch=self.data_loader.train_iterations)
            # model
            pred = self.model(x, progress)
            # loss
            cur_loss = self.loss(pred, y)
            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during training...')
            # optimizer
            self.optimizer.zero_grad()
            cur_loss.backward()
            self.optimizer.step()

            top1, top5 = cls_accuracy(pred.data, y.data, topk=(1, 5))

            epoch_loss.update(cur_loss.item())
            top1_acc.update(top1.item(), x.size(0))
            top5_acc.update(top5.item(), x.size(0))

            self.current_iteration += 1
            current_batch += 1

            self.summary_writer.add_scalar("epoch/loss", epoch_loss.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/accuracy", top1_acc.val,
                                           self.current_iteration)
        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "loss: " + str(epoch_loss.val) +
                         "- Top1 Acc: " + str(top1_acc.val) + "- Top5 Acc: " +
                         str(top5_acc.val))

    def validate(self):
        """
        One epoch validation
        :return:
        """
        tqdm_batch = tqdm(self.data_loader.valid_loader,
                          total=self.data_loader.valid_iterations,
                          desc="Valiation at -{}-".format(self.current_epoch))

        # set the model in evaluation mode
        self.model.eval()

        epoch_loss = AverageMeter()
        top1_acc = AverageMeter()
        top5_acc = AverageMeter()

        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.cuda(self.config.async_loading), y.cuda(
                    self.config.async_loading)

            x, y = Variable(x), Variable(y)
            # model
            pred = self.model(x)
            # loss
            cur_loss = self.loss(pred, y)
            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during validation...')

            top1, top5 = cls_accuracy(pred.data, y.data, topk=(1, 5))
            epoch_loss.update(cur_loss.item())
            top1_acc.update(top1.item(), x.size(0))
            top5_acc.update(top5.item(), x.size(0))

        self.logger.info("Validation results at epoch-" +
                         str(self.current_epoch) + " | " + "loss: " +
                         str(epoch_loss.avg) + "- Top1 Acc: " +
                         str(top1_acc.val) + "- Top5 Acc: " +
                         str(top5_acc.val))

        tqdm_batch.close()

        return top1_acc.avg

    def finalize(self):
        """
        Finalizes all the operations of the two main classes of the process: the operator and the data loader
        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint()
        self.summary_writer.export_scalars_to_json("{}all_scalars.json".format(
            self.config.summary_dir))
        self.summary_writer.close()
        self.data_loader.finalize()
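The training and validation loops above rely on AverageMeter and cls_accuracy helpers that are imported elsewhere in the project. The sketches below are illustrative assumptions of what such helpers typically look like, not the project's own implementations.

import torch

class AverageMeter:
    """Tracks the latest value, running sum, count and average of a metric."""
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def cls_accuracy(output, target, topk=(1,)):
    """Top-k classification accuracy in percent, one value per requested k."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * (100.0 / target.size(0))
            for k in topk]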
Example #4
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.onlineExpert = ComputeECBSSolution(self.config)
        self.dataTransformer = DataTransformer(self.config)
        self.recorder = MonitoringMultiAgentPerformance(self.config)

        self.model = DecentralPlannerNet(self.config)
        self.logger.info("Model: \n".format(print(self.model)))

        # define data_loader
        self.data_loader = DecentralPlannerDataLoader(config=config)

        # define loss
        self.loss = CrossEntropyLoss()
        self.l1_reg = L1Regularizer(self.model)
        self.l2_reg = L2Regularizer(self.model)

        # define optimizers
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate,
                                    weight_decay=self.config.weight_decay)
        print(self.config.weight_decay)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.config.max_epoch, eta_min=1e-6)

        # for param in self.model.parameters():
        #     print(param)

        # for name, param in self.model.state_dict().items():
        #     print(name, param)

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.current_iteration_validStep = 0
        self.rateReachGoal = 0.0

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda

        # set the manual seed for torch
        self.manual_seed = self.config.seed
        if self.cuda:
            torch.cuda.manual_seed_all(self.manual_seed)
            self.config.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            self.model = self.model.to(self.config.device)
            self.loss = self.loss.to(self.config.device)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.config.device = torch.device("cpu")
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU*****\n")

        # Load the model from the latest checkpoint; if none is found, start from scratch.
        if self.config.train_TL or self.config.test_general:
            self.load_pretrained_checkpoint(self.config.test_epoch,
                                            lastest=self.config.lastest_epoch,
                                            best=self.config.best_epoch)
        else:
            self.load_checkpoint(self.config.test_epoch,
                                 lastest=self.config.lastest_epoch,
                                 best=self.config.best_epoch)
        # Summary Writer

        self.robot = multiRobotSim(self.config)
        self.switch_toOnlineExpert = False
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='NerualMAPP')
        self.plot_graph = True
        self.save_dump_input = False
        self.dummy_input = None
        self.dummy_gso = None
        self.time_record = None
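These DecentralPlanner agents checkpoint the epoch and iteration counters together with the model, optimizer and scheduler state. Below is a minimal, self-contained sketch of that round trip; the toy model and the file name are assumptions for illustration only.

import torch

model = torch.nn.Linear(4, 2)                  # toy stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

state = {
    'epoch': 1,
    'iteration': 100,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}
torch.save(state, 'checkpoint.pth.tar')        # assumed file name

checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])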
Example #5
class DecentralPlannerAgentLocalWithOnlineExpert(BaseAgent):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.onlineExpert = ComputeECBSSolution(self.config)
        self.dataTransformer = DataTransformer(self.config)
        self.recorder = MonitoringMultiAgentPerformance(self.config)

        self.model = DecentralPlannerNet(self.config)
        self.logger.info("Model: \n".format(print(self.model)))

        # define data_loader
        self.data_loader = DecentralPlannerDataLoader(config=config)

        # define loss
        self.loss = CrossEntropyLoss()
        self.l1_reg = L1Regularizer(self.model)
        self.l2_reg = L2Regularizer(self.model)

        # define optimizers
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate,
                                    weight_decay=self.config.weight_decay)
        print(self.config.weight_decay)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.config.max_epoch, eta_min=1e-6)

        # for param in self.model.parameters():
        #     print(param)

        # for name, param in self.model.state_dict().items():
        #     print(name, param)

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.current_iteration_validStep = 0
        self.rateReachGoal = 0.0

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda

        # set the manual seed for torch
        self.manual_seed = self.config.seed
        if self.cuda:
            torch.cuda.manual_seed_all(self.manual_seed)
            self.config.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            self.model = self.model.to(self.config.device)
            self.loss = self.loss.to(self.config.device)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.config.device = torch.device("cpu")
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU*****\n")

        # Load the model from the latest checkpoint; if none is found, start from scratch.
        if self.config.train_TL or self.config.test_general:
            self.load_pretrained_checkpoint(self.config.test_epoch,
                                            lastest=self.config.lastest_epoch,
                                            best=self.config.best_epoch)
        else:
            self.load_checkpoint(self.config.test_epoch,
                                 lastest=self.config.lastest_epoch,
                                 best=self.config.best_epoch)
        # Summary Writer

        self.robot = multiRobotSim(self.config)
        self.switch_toOnlineExpert = False
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='NerualMAPP')
        self.plot_graph = True
        self.save_dump_input = False
        self.dummy_input = None
        self.dummy_gso = None
        self.time_record = None
        # dummy_input = (torch.zeros(self.config.map_w,self.config.map_w, 3),)
        # self.summary_writer.add_graph(self.model, dummy_input)

    def save_checkpoint(self, epoch, is_best=0, lastest=True):
        """
        Checkpoint saver
        :param epoch: current epoch number, used to name per-epoch checkpoints
        :param is_best: boolean flag to indicate whether current checkpoint's accuracy is the best so far
        :param lastest: if True, save under the rolling name 'checkpoint.pth.tar'
        :return:
        """
        if lastest:
            file_name = "checkpoint.pth.tar"
        else:
            file_name = "checkpoint_{:03d}.pth.tar".format(epoch)
        state = {
            'epoch': self.current_epoch + 1,
            'iteration': self.current_iteration,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }

        # Save the state
        torch.save(state, os.path.join(self.config.checkpoint_dir, file_name))
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(
                os.path.join(self.config.checkpoint_dir, file_name),
                os.path.join(self.config.checkpoint_dir, 'model_best.pth.tar'))

    def load_pretrained_checkpoint(self, epoch, lastest=True, best=False):
        """
        Latest checkpoint loader
        :param epoch: epoch number of the checkpoint to load (used when lastest and best are False)
        :param lastest: if True, load the rolling 'checkpoint.pth.tar'
        :param best: if True, load 'model_best.pth.tar'
        :return:
        """
        if lastest:
            file_name = "checkpoint.pth.tar"
        elif best:
            file_name = "model_best.pth.tar"
        else:
            file_name = "checkpoint_{:03d}.pth.tar".format(epoch)

        filename = os.path.join(self.config.checkpoint_dir_load, file_name)
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            # checkpoint = torch.load(filename)
            checkpoint = torch.load(filename,
                                    map_location='cuda:{}'.format(
                                        self.config.gpu_device))

            self.current_epoch = checkpoint['epoch']

            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_dir_load, checkpoint['epoch'],
                        checkpoint['iteration']))

            if self.config.train_TL:
                param_name_GFL = '*GFL*'
                param_name_action = '*actions*'
                assert param_name_GFL != '', 'you must specify the names of the parameters to be re-trained'
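                # Transfer learning: only parameters whose names match the GFL or action
                # patterns are left trainable below; every other parameter is frozen.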
                for model_param_name, model_param_value in self.model.named_parameters(
                ):
                    # print("---All layers -- \n", model_param_name)
                    if fnmatch(model_param_name, param_name_GFL) or fnmatch(
                            model_param_name, param_name_action
                    ):  # and model_param_name.endswith('weight'):
                        # print("---retrain layers -- \n", model_param_name)
                        model_param_value.requires_grad = True
                    else:
                        # print("---freezed layers -- \n", model_param_name)
                        model_param_value.requires_grad = False

        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def load_checkpoint(self, epoch, lastest=True, best=False):
        """
        Latest checkpoint loader
        :param epoch: epoch number of the checkpoint to load (used when lastest and best are False)
        :param lastest: if True, load the rolling 'checkpoint.pth.tar'
        :param best: if True, load 'model_best.pth.tar'
        :return:
        """
        if lastest:
            file_name = "checkpoint.pth.tar"
        elif best:
            file_name = "model_best.pth.tar"
        else:
            file_name = "checkpoint_{:03d}.pth.tar".format(epoch)
        filename = os.path.join(self.config.checkpoint_dir, file_name)
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            # checkpoint = torch.load(filename)
            checkpoint = torch.load(filename,
                                    map_location='cuda:{}'.format(
                                        self.config.gpu_device))

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_dir, checkpoint['epoch'],
                        checkpoint['iteration']))
        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def run(self):
        """
        The main operator
        :return:
        """
        assert self.config.mode in ['train', 'test']
        try:
            if self.config.mode == 'test':
                print("-------test------------")
                start = time.process_time()
                self.test('test')
                self.time_record = time.process_time() - start
                # self.test('test_trainingSet')
                # self.pipeline_onlineExpert(self.current_epoch)
            else:
                self.train()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training loop
        :return:
        """

        for epoch in range(self.current_epoch, self.config.max_epoch + 1):
            # for epoch in range(1, self.config.max_epoch + 1):
            self.current_epoch = epoch
            # TODO: Optional 1: del dataloader before train
            self.train_one_epoch()
            self.logger.info('Train {} on Epoch {}: Learning Rate: {}'.format(
                self.config.exp_name, self.current_epoch,
                self.scheduler.get_lr()))
            print('Train {} on Epoch {} Learning Rate: {}'.format(
                self.config.exp_name, self.current_epoch,
                self.scheduler.get_lr()))

            rateReachGoal = 0.0
            if self.config.num_agents >= 10:
                if epoch % self.config.validate_every == 0:
                    rateReachGoal = self.test(self.config.mode)
                    self.switch_toOnlineExpert = True
                    self.test('test_trainingSet')
                    # self.test_step()
                    self.save_checkpoint(epoch, lastest=False)
            else:
                if epoch <= 4:
                    rateReachGoal = self.test(self.config.mode)
                    self.switch_toOnlineExpert = True
                    self.test('test_trainingSet')
                    # self.test_step()
                    self.save_checkpoint(epoch, lastest=False)
                elif epoch % self.config.validate_every == 0:
                    rateReachGoal = self.test(self.config.mode)
                    self.switch_toOnlineExpert = True
                    self.test('test_trainingSet')
                    # self.test_step()
                    self.save_checkpoint(epoch, lastest=False)
                    # pass

            is_best = rateReachGoal > self.rateReachGoal
            if is_best:
                self.rateReachGoal = rateReachGoal
            self.save_checkpoint(epoch, is_best=is_best, lastest=True)
            self.scheduler.step()
            # TODO: Optional 2: del dataloader after train
            self.excuation_onlineExport(epoch)

    def excuation_onlineExport(self, epoch):
        if epoch >= self.config.Start_onlineExpert:
            if self.config.num_agents >= 10:
                if epoch % self.config.validate_every == 0:

                    self.pipeline_onlineExpert(epoch)
            else:
                if epoch <= 4:
                    self.pipeline_onlineExpert(epoch)
                elif epoch % self.config.validate_every == 0:
                    self.pipeline_onlineExpert(epoch)

    def pipeline_onlineExpert(self, epoch):
        # TODO: del dataloader
        # create dataloader
        self.onlineExpert.set_up()
        self.onlineExpert.computeSolution()
        self.dataTransformer.set_up(epoch)
        self.dataTransformer.solutionTransformer()
        del self.data_loader
        self.data_loader = DecentralPlannerDataLoader(config=self.config)

    def train_one_epoch(self):
        """
        One epoch of training
        :return:
        """

        # Set the model to be in training mode
        self.model.train()
        # for param in self.model.parameters():
        #     print(param.requires_grad)
        # for batch_idx, (input, target, GSO) in enumerate(self.data_loader.train_loader):
        for batch_idx, (batch_input, batch_target, _, batch_GSO,
                        _) in enumerate(self.data_loader.train_loader):

            inputGPU = batch_input.to(self.config.device)
            gsoGPU = batch_GSO.to(self.config.device)
            # gsoGPU = gsoGPU.unsqueeze(0)
            targetGPU = batch_target.to(self.config.device)
            batch_targetGPU = targetGPU.permute(1, 0, 2)
            self.optimizer.zero_grad()

            # loss
            loss = 0

            # model

            self.model.addGSO(gsoGPU)
            predict = self.model(inputGPU)
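            # addGSO feeds the graph shift operator (the agents' communication graph) to
            # the GNN layers; the forward pass then returns per-agent action predictions.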

            for id_agent in range(self.config.num_agents):
                # for output, target in zip(predict, target):
                batch_predict_currentAgent = predict[id_agent][:]
                batch_target_currentAgent = batch_targetGPU[id_agent][:][:]
                loss = loss + self.loss(
                    batch_predict_currentAgent,
                    torch.max(batch_target_currentAgent, 1)[1])
                # print(loss)

            loss = loss / self.config.num_agents

            loss.backward()
            # for param in self.model.parameters():
            #     print(param.grad)
            self.optimizer.step()
            if batch_idx % self.config.log_interval == 0:
                self.logger.info(
                    'Train {} on Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.
                    format(
                        self.config.exp_name, self.current_epoch,
                        batch_idx * len(inputGPU),
                        len(self.data_loader.train_loader.dataset),
                        100. * batch_idx / len(self.data_loader.train_loader),
                        loss.item()))
            self.current_iteration += 1

            # print(loss)
            log_loss = loss.item()
            self.summary_writer.add_scalar("iteration/loss", log_loss,
                                           self.current_iteration)

    def test_step(self):
        """
        One epoch of testing the per-step decision-making accuracy
        :return:
        """

        # Set the model to be in evaluation mode
        self.model.eval()

        log_loss_validStep = []
        for batch_idx, (batch_input, batch_target, _, batch_GSO,
                        _) in enumerate(self.data_loader.validStep_loader):

            inputGPU = batch_input.to(self.config.device)
            gsoGPU = batch_GSO.to(self.config.device)
            # gsoGPU = gsoGPU.unsqueeze(0)
            targetGPU = batch_target.to(self.config.device)
            batch_targetGPU = targetGPU.permute(1, 0, 2)
            self.optimizer.zero_grad()

            # loss
            loss_validStep = 0

            # model
            self.model.addGSO(gsoGPU)
            predict = self.model(inputGPU)

            for id_agent in range(self.config.num_agents):
                # for output, target in zip(predict, target):
                batch_predict_currentAgent = predict[id_agent][:]
                batch_target_currentAgent = batch_targetGPU[id_agent][:][:]
                loss_validStep = loss_validStep + self.loss(
                    batch_predict_currentAgent,
                    torch.max(batch_target_currentAgent, 1)[1])
                # print(loss)

            loss_validStep = loss_validStep / self.config.num_agents

            if batch_idx % self.config.log_interval == 0:
                self.logger.info(
                    'ValidStep {} on Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                    .format(
                        self.config.exp_name, self.current_epoch,
                        batch_idx * len(inputGPU),
                        len(self.data_loader.validStep_loader.dataset), 100. *
                        batch_idx / len(self.data_loader.validStep_loader),
                        loss_validStep.item()))

            log_loss_validStep.append(loss_validStep.item())

            # self.current_iteration_validStep += 1
            # self.summary_writer.add_scalar("iteration/loss_validStep", loss_validStep.item(), self.current_iteration_validStep)
            # print(loss)

        avg_loss = sum(log_loss_validStep) / len(log_loss_validStep)
        self.summary_writer.add_scalar("epoch/loss_validStep", avg_loss,
                                       self.current_epoch)

    def test(self, mode):
        """
        One cycle of model validation
        :return:
        """
        self.model.eval()
        if mode == 'test':
            dataloader = self.data_loader.test_loader
            label = 'test'
        elif mode == 'test_trainingSet':
            dataloader = self.data_loader.test_trainingSet_loader
            label = 'test_training'
            if self.switch_toOnlineExpert:
                self.robot.createfolder_failure_cases()
        else:
            dataloader = self.data_loader.valid_loader
            label = 'valid'

        size_dataset = dataloader.dataset.data_size
        self.logger.info('\n{} set on {}: {} cases\n'.format(
            label, self.config.exp_name, size_dataset))

        self.recorder.reset()
        # maxstep = self.robot.getMaxstep()
        with torch.no_grad():
            for input, target, makespan, _, tensor_map in dataloader:

                inputGPU = input.to(self.config.device)
                targetGPU = target.to(self.config.device)

                log_result = self.mutliAgent_ActionPolicy(
                    inputGPU, targetGPU, makespan, tensor_map,
                    self.recorder.count_validset, mode)
                self.recorder.update(self.robot.getMaxstep(), log_result)

        self.summary_writer = self.recorder.summary(label, self.summary_writer,
                                                    self.current_epoch)

        self.logger.info(
            'Accuracy(reachGoalnoCollision): {} \n  '
            'DeteriorationRate(MakeSpan): {} \n  '
            'DeteriorationRate(FlowTime): {} \n  '
            'Rate(collisionPredictedinLoop): {} \n  '
            'Rate(FailedReachGoalbyCollisionShielding): {} \n '.format(
                round(self.recorder.rateReachGoal, 4),
                round(self.recorder.avg_rate_deltaMP, 4),
                round(self.recorder.avg_rate_deltaFT, 4),
                round(self.recorder.rateCollisionPredictedinLoop, 4),
                round(self.recorder.rateFailedReachGoalSH, 4),
            ))

        # if self.config.mode == 'train' and self.plot_graph:
        #     self.summary_writer.add_graph(self.model,None)
        #     self.plot_graph = False

        return self.recorder.rateReachGoal

    def mutliAgent_ActionPolicy(self, input, load_target, makespanTarget,
                                tensor_map, ID_dataset, mode):

        self.robot.setup(input, load_target, makespanTarget, tensor_map,
                         ID_dataset)
        maxstep = self.robot.getMaxstep()

        allReachGoal = False
        noReachGoalbyCollsionShielding = False

        check_collisionFreeSol = False

        check_CollisionHappenedinLoop = False

        check_CollisionPredictedinLoop = False

        findOptimalSolution = False

        compare_makespan, compare_flowtime = self.robot.getOptimalityMetrics()
        currentStep = 0

        Case_start = time.process_time()
        Time_cases_ForwardPass = []
        for step in range(maxstep):
            currentStep = step + 1
            currentState = self.robot.getCurrentState()
            currentStateGPU = currentState.to(self.config.device)

            gso = self.robot.getGSO(step)
            gsoGPU = gso.to(self.config.device)
            self.model.addGSO(gsoGPU)
            # self.model.addGSO(gsoGPU.unsqueeze(0))

            step_start = time.process_time()
            actionVec_predict = self.model(currentStateGPU)

            time_ForwardPass = time.process_time() - step_start

            Time_cases_ForwardPass.append(time_ForwardPass)
            allReachGoal, check_moveCollision, check_predictCollision = self.robot.move(
                actionVec_predict, currentStep)

            if check_moveCollision:
                check_CollisionHappenedinLoop = True

            if check_predictCollision:
                check_CollisionPredictedinLoop = True

            if allReachGoal:
                # findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality()
                # print("### Case - {} within maxstep - RealGoal: {} ~~~~~~~~~~~~~~~~~~~~~~".format(ID_dataset, allReachGoal))
                break
            elif currentStep >= (maxstep):
                # print("### Case - {} exceed maxstep - RealGoal: {} - check_moveCollision: {} - check_predictCollision: {}".format(ID_dataset, allReachGoal, check_CollisionHappenedinLoop, check_CollisionPredictedinLoop))
                break

        num_agents_reachgoal = self.robot.count_numAgents_ReachGoal()
        store_GSO, store_communication_radius = self.robot.count_GSO_communcationRadius(
            currentStep)

        if allReachGoal and not check_CollisionHappenedinLoop:
            check_collisionFreeSol = True
            noReachGoalbyCollsionShielding = False
            findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality(
                True)
            if self.config.log_anime and self.config.mode == 'test':
                self.robot.save_success_cases('success')

        if currentStep >= (maxstep):
            findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality(
                False)
            if mode == 'test_trainingSet' and self.switch_toOnlineExpert:
                self.robot.save_failure_cases()

        if currentStep >= (
                maxstep
        ) and not allReachGoal and check_CollisionPredictedinLoop and not check_CollisionHappenedinLoop:
            findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality(
                False)
            # print("### Case - {} -Step{} exceed maxstep({})- ReachGoal: {} due to CollsionShielding \n".format(ID_dataset,currentStep,maxstep, allReachGoal))
            noReachGoalbyCollsionShielding = True
            if self.config.log_anime and self.config.mode == 'test':
                self.robot.save_success_cases('failure')
        time_record = time.process_time() - Case_start

        if self.config.mode == 'test':
            exp_status = "################## {} - End of loop ################## ".format(
                self.config.exp_name)
            case_status = "####### Case{} \t Computation time:{} \t Step{}/{}\t- AllReachGoal-{}\n".format(
                ID_dataset, time_record, currentStep, maxstep, allReachGoal)

            self.logger.info('{} \n {}'.format(exp_status, case_status))

        # if self.config.mode == 'test':
        #     self.robot.draw(ID_dataset)

        # return [allReachGoal, noReachGoalbyCollsionShielding, findOptimalSolution, check_collisionFreeSol, check_CollisionPredictedinLoop, makespanPredict, makespanTarget, flowtimePredict,flowtimeTarget,num_agents_reachgoal]

        return allReachGoal, noReachGoalbyCollsionShielding, findOptimalSolution, check_collisionFreeSol, check_CollisionPredictedinLoop, compare_makespan, compare_flowtime, num_agents_reachgoal, store_GSO, store_communication_radius, time_record, Time_cases_ForwardPass

    def finalize(self):
        """
        Finalizes all the operations of the two main classes of the process: the operator and the data loader
        :return:
        """
        if self.config.mode == 'train':
            print(self.model)
        print("Experiment on {} finished.".format(self.config.exp_name))
        print("Please wait while finalizing the operation.. Thank you")
        # self.save_checkpoint()
        self.summary_writer.export_scalars_to_json("{}all_scalars.json".format(
            self.config.summary_dir))
        self.summary_writer.close()
        self.data_loader.finalize()
        if self.config.mode == 'test':
            print("################## End of testing ################## ")
            print("Computation time:\t{} ".format(self.time_record))

class DecentralPlannerAgentLocal(BaseAgent):

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.recorder = MonitoringMultiAgentPerformance(self.config)

        self.model = DecentralPlannerNet(self.config, config.feature_noise_std, config.sybil_attack_count)
        self.logger.info("Model: \n".format(print(self.model)))

        # Add additional noise model parameters to config
        self.map_noise_prob = config.map_noise_prob
        self.map_shift_units = config.map_shift_units
        self.move_noise_std = config.move_noise_std
        self.comm_dropout_param = config.comm_dropout_param

        # Add additonal attack model parameters to config
        self.rogue_agent_count = config.rogue_agent_count

        # define data_loader
        self.data_loader = DecentralPlannerDataLoader(config=config)

        # define loss
        self.loss = CrossEntropyLoss()
        self.l1_reg = L1Regularizer(self.model)
        self.l2_reg = L2Regularizer(self.model)

        # define optimizers
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate, weight_decay=self.config.weight_decay)
        print(self.config.weight_decay)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.config.max_epoch, eta_min=1e-6)

        # for param in self.model.parameters():
        #     print(param)

        # for name, param in self.model.state_dict().items():
        #     print(name, param)

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.current_iteration_validStep = 0
        self.rateReachGoal = 0.0

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info("WARNING: You have a CUDA device, so you should probably enable CUDA")

        self.cuda = self.is_cuda & self.config.cuda

        # set the manual seed for torch
        self.manual_seed = self.config.seed
        if self.cuda:
            torch.cuda.manual_seed_all(self.manual_seed)
            self.config.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            self.model = self.model.to(self.config.device)
            self.loss = self.loss.to(self.config.device)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.config.device = torch.device("cpu")
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU*****\n")

        # Load the model from the latest checkpoint; if none is found, start from scratch.
        if self.config.train_TL or self.config.test_general:
            self.load_pretrained_checkpoint(self.config.test_epoch, lastest=self.config.lastest_epoch, best=self.config.best_epoch)
        else:
            self.load_checkpoint(self.config.test_epoch, lastest=self.config.lastest_epoch, best=self.config.best_epoch)
        # Summary Writer

        self.robot = multiRobotSim(self.config)
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir, comment='NerualMAPP')
        self.plot_graph = True
        self.save_dump_input = False
        self.dummy_input = None
        self.dummy_gso = None
        self.time_record = None
        # dummy_input = (torch.zeros(self.config.map_w,self.config.map_w, 3),)
        # self.summary_writer.add_graph(self.model, dummy_input)
        self.results_file = open(self.config.data_root + '/results.txt', 'a+')

    def save_checkpoint(self, epoch, is_best=0, lastest=True):
        """
        Checkpoint saver
        :param epoch: current epoch number, used to name per-epoch checkpoints
        :param is_best: boolean flag to indicate whether current checkpoint's accuracy is the best so far
        :param lastest: if True, save under the rolling name 'checkpoint.pth.tar'
        :return:
        """
        if lastest:
            file_name = "checkpoint.pth.tar"
        else:
            file_name = "checkpoint_{:03d}.pth.tar".format(epoch)
        state = {
            'epoch': self.current_epoch + 1,
            'iteration': self.current_iteration,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }

        # Save the state
        torch.save(state, os.path.join(self.config.checkpoint_dir, file_name))
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(os.path.join(self.config.checkpoint_dir, file_name),
                            os.path.join(self.config.checkpoint_dir, 'model_best.pth.tar'))

    def load_pretrained_checkpoint(self, epoch, lastest=True, best=False):
        """
        Latest checkpoint loader
        :param epoch: epoch number of the checkpoint to load (used when lastest and best are False)
        :param lastest: if True, load the rolling 'checkpoint.pth.tar'
        :param best: if True, load 'model_best.pth.tar'
        :return:
        """
        if lastest:
            file_name = "checkpoint.pth.tar"
        elif best:
            file_name = "model_best.pth.tar"
        else:
            file_name = "checkpoint_{:03d}.pth.tar".format(epoch)

        filename = os.path.join(self.config.checkpoint_dir_load, file_name)
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=torch.device('cpu'))
            #checkpoint = torch.load(filename, map_location='cuda:{}'.format(self.config.gpu_device))

            self.current_epoch = checkpoint['epoch']

            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                             .format(self.config.checkpoint_dir_load, checkpoint['epoch'], checkpoint['iteration']))

            if self.config.train_TL:
                param_name_GFL = '*GFL*'
                param_name_action = '*actions*'
                assert param_name_GFL != '', 'you must specify the names of the parameters to be re-trained'
                for model_param_name, model_param_value in self.model.named_parameters():
                    # print("---All layers -- \n", model_param_name)
                    if fnmatch(model_param_name, param_name_GFL) or fnmatch(model_param_name, param_name_action):  # and model_param_name.endswith('weight'):
                        # print("---retrain layers -- \n", model_param_name)
                        model_param_value.requires_grad = True
                    else:
                        # print("---freezed layers -- \n", model_param_name)
                        model_param_value.requires_grad = False


        except OSError as e:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
            self.logger.info("**First time to train**")


    def load_checkpoint(self, epoch, lastest=True, best=False):
        """
        Latest checkpoint loader
        :param epoch: epoch number of the checkpoint to load (used when lastest and best are False)
        :param lastest: if True, load the rolling 'checkpoint.pth.tar'
        :param best: if True, load 'model_best.pth.tar'
        :return:
        """
        if lastest:
            file_name = "checkpoint.pth.tar"
        elif best:
            file_name = "model_best.pth.tar"
        else:
            file_name = "checkpoint_{:03d}.pth.tar".format(epoch)
        filename = os.path.join(self.config.checkpoint_dir, file_name)
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=torch.device('cpu'))
            #checkpoint = torch.load(filename, map_location='cuda:{}'.format(self.config.gpu_device))

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                             .format(self.config.checkpoint_dir, checkpoint['epoch'], checkpoint['iteration']))
        except OSError as e:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def run(self):
        """
        The main operator
        :return:
        """
        assert self.config.mode in ['train', 'test']
        try:
            if self.config.mode == 'test':
                print("-------test------------")
                start = time.process_time()
                self.test('test')
                self.time_record = time.process_time()-start
                # self.test('test_trainingSet')
            else:
                self.train()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training loop
        :return:
        """

        for epoch in range(self.current_epoch, self.config.max_epoch + 1):
        # for epoch in range(1, self.config.max_epoch + 1):
            self.current_epoch = epoch
            self.train_one_epoch()
            # self.train_one_epoch_BPTT()
            self.logger.info('Train {} on Epoch {}: Learning Rate: {}'.format(self.config.exp_name, self.current_epoch, self.scheduler.get_lr()))
            print('Train {} on Epoch {} Learning Rate: {}'.format(self.config.exp_name, self.current_epoch, self.scheduler.get_lr()))

            rateReachGoal = 0.0
            if self.config.num_agents >= 10:
                if epoch % self.config.validate_every == 0:
                    rateReachGoal = self.test(self.config.mode)
                    self.test('test_trainingSet')
                    # self.test_step()
                    self.save_checkpoint(epoch, lastest=False)
            else:
                if epoch <= 4:
                    rateReachGoal = self.test(self.config.mode)
                    self.test('test_trainingSet')
                    # self.test_step()
                    self.save_checkpoint(epoch, lastest=False)
                elif epoch % self.config.validate_every == 0:
                    rateReachGoal =  self.test(self.config.mode)
                    self.test('test_trainingSet')
                    # self.test_step()
                    self.save_checkpoint(epoch, lastest=False)
                    # pass


            is_best = rateReachGoal > self.rateReachGoal
            if is_best:
                self.rateReachGoal = rateReachGoal
            self.save_checkpoint(epoch, is_best=is_best, lastest=True)
            self.scheduler.step()

    def train_one_epoch(self):
        """
        One epoch of training
        :return:
        """

        # Set the model to be in training mode
        self.model.train()
        # for param in self.model.parameters():
        #     print(param.requires_grad)
        # for batch_idx, (input, target, GSO) in enumerate(self.data_loader.train_loader):
        for batch_idx, (batch_input, batch_target, _, batch_GSO, _) in enumerate(self.data_loader.train_loader):

            inputGPU = batch_input.to(self.config.device)
            gsoGPU = batch_GSO.to(self.config.device)
            # gsoGPU = gsoGPU.unsqueeze(0)
            targetGPU = batch_target.to(self.config.device)
            batch_targetGPU = targetGPU.permute(1,0,2)
            self.optimizer.zero_grad()

            # loss
            loss = 0

            # model

            self.model.addGSO(gsoGPU)
            predict = self.model(inputGPU)


            for id_agent in range(self.config.num_agents):
            # for output, target in zip(predict, target):
                batch_predict_currentAgent = predict[id_agent][:]
                batch_target_currentAgent = batch_targetGPU[id_agent][:][:]
                loss = loss + self.loss(batch_predict_currentAgent,  torch.max(batch_target_currentAgent, 1)[1])
                # print(loss)

            loss = loss/self.config.num_agents

            loss.backward()
            # for param in self.model.parameters():
            #     print(param.grad)
            self.optimizer.step()
            if batch_idx % self.config.log_interval == 0:
                self.logger.info('Train {} on Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(self.config.exp_name,
                    self.current_epoch, batch_idx * len(inputGPU), len(self.data_loader.train_loader.dataset),
                           100. * batch_idx / len(self.data_loader.train_loader), loss.item()))
            self.current_iteration += 1

            # print(loss)
            log_loss = loss.item()
            self.summary_writer.add_scalar("iteration/loss", log_loss, self.current_iteration)

    def train_one_epoch_BPTT(self):
        """
        One epoch of training
        :return:
        """

        # Set the model to be in training mode
        self.model.train()

        # seq_length = 5
        seq_length = 10
        # for batch_idx, (input, target, GSO) in enumerate(self.data_loader.train_loader):
        # for batch_idx, (batch_input, batch_GSO, batch_target) in enumerate(self.data_loader.train_loader):
        for batch_idx, (batch_input, batch_target, list_makespan, batch_GSO, _) in enumerate(self.data_loader.train_loader):

            batch_makespan = max(list_makespan)

            batch_size = batch_input.shape[1]
            # print(mask_makespan)
            inputGPU = batch_input.to(self.config.device)
            gsoGPU = batch_GSO.to(self.config.device)

            targetGPU = batch_target.to(self.config.device)

            # for step in range(batch_makespan):
            # self.model.initialize_hidden(batch_size)

            log_loss = []

            for id_seq in range(0, batch_makespan, seq_length):

                # solution # 2
                if id_seq == 0:
                    self.model.initialize_hidden(batch_size)
                else:
                    self.model.detach_hidden()
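                # Truncated BPTT: re-initialising or detaching the hidden state here cuts
                # the autograd graph, so gradients only flow back through the last
                # seq_length steps rather than the whole makespan.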

                if id_seq + seq_length + 1 >= batch_makespan:
                    id_seq_end = batch_makespan
                else:
                    id_seq_end = id_seq + seq_length
                # retain_graph is only consumed by the commented-out "solution # 1" backward call below
                self.retain_graph = True
                # loss
                loss = 0
                # backpropagate after aggregating the loss over a window of seq_length steps instead of the full makespan
                for step in range(id_seq, id_seq_end):
                    # Back Propagation through time (BPTT)
                    step_inputGPU = inputGPU[step][:]
                    step_targetGPU = targetGPU[step][:]

                    step_gsoGPU = gsoGPU[step][:]
                    step_targetGPU = step_targetGPU.permute(1, 0, 2)
                    self.optimizer.zero_grad()

                    # model
                    self.model.addGSO(step_gsoGPU)
                    step_predict = self.model(step_inputGPU)

                    for id_agent in range(self.config.num_agents):
                        # for output, target in zip(predict, target):
                        batch_predict_currentAgent = step_predict[id_agent][:]
                        batch_target_currentAgent = step_targetGPU[id_agent][:]
                        loss = loss + self.loss(batch_predict_currentAgent,
                                                torch.max(batch_target_currentAgent, 1)[1]) / self.config.num_agents
                        # print(loss)


                # optimizer
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
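                # clip the total gradient norm at 5 to keep the truncated-BPTT updates stable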
                # solution # 1
                # loss.backward(retain_graph=self.retain_graph)



                #https://github.com/pytorch/examples/blob/e11e0796fc02cc2cd5b6ec2ad7cea21f77e25402/word_language_model/main.py#L155
                # torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)#args.clip)
                # for p in model.parameters():
                #     p.data.add_(-lr, p.grad.data)

                # for param in self.model.parameters():
                #     print(param.grad)

                self.optimizer.step()


                log_loss.append(loss.item())

            avg_loss = sum(log_loss) / len(log_loss)
            if batch_idx % self.config.log_interval == 0:
                self.logger.info('Train {} on Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(self.config.exp_name,
                                                                                                self.current_epoch,
                                                                                                batch_idx * batch_size,
                                                                                                len(self.data_loader.train_loader.dataset),
                                                                                                100. * batch_idx / len(self.data_loader.train_loader),
                                                                                                avg_loss))
            self.current_iteration += 1

            # print(loss)
            # log_loss = loss.item()
            self.summary_writer.add_scalar("iteration/loss", avg_loss, self.current_iteration)

    def test_step(self):
        """
        One epoch of testing the accuracy of decision-making of each step
        :return:
        """

        # Set the model to evaluation mode
        self.model.eval()

        log_loss_validStep = []
        for batch_idx, (batch_input, batch_target, _, batch_GSO, _) in enumerate(self.data_loader.validStep_loader):

            inputGPU = batch_input.to(self.config.device)
            gsoGPU = batch_GSO.to(self.config.device)
            # gsoGPU = gsoGPU.unsqueeze(0)
            targetGPU = batch_target.to(self.config.device)
            batch_targetGPU = targetGPU.permute(1, 0, 2)
            self.optimizer.zero_grad()
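            # no backward pass is performed in this evaluation loop, so zeroing the gradients
            # here is harmless; it only mirrors the structure of train_one_epoch()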

            # loss
            loss_validStep = 0

            # model
            self.model.addGSO(gsoGPU)
            predict = self.model(inputGPU)

            for id_agent in range(self.config.num_agents):
                # for output, target in zip(predict, target):
                batch_predict_currentAgent = predict[id_agent][:]
                batch_target_currentAgent = batch_targetGPU[id_agent][:][:]
                loss_validStep = loss_validStep + self.loss(batch_predict_currentAgent, torch.max(batch_target_currentAgent, 1)[1])
                # print(loss)

            loss_validStep = loss_validStep/self.config.num_agents

            if batch_idx % self.config.log_interval == 0:
                self.logger.info('ValidStep {} on Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(self.config.exp_name,
                                                                                                self.current_epoch,
                                                                                                batch_idx * len(inputGPU),
                                                                                                len(self.data_loader.validStep_loader.dataset),
                                                                                                100. * batch_idx / len(self.data_loader.validStep_loader),
                                                                                                loss_validStep.item()))

            log_loss_validStep.append(loss_validStep.item())

            # self.current_iteration_validStep += 1
            # self.summary_writer.add_scalar("iteration/loss_validStep", loss_validStep.item(), self.current_iteration_validStep)
            # print(loss)


        avg_loss = sum(log_loss_validStep)/len(log_loss_validStep)
        self.summary_writer.add_scalar("epoch/loss_validStep", avg_loss, self.current_epoch)

    def test(self, mode):
        """
        One cycle of model validation
        :return:
        """
        self.model.eval()
        if mode == 'test':
            dataloader = self.data_loader.test_loader
            label = 'test'
        elif mode == 'test_trainingSet':
            dataloader = self.data_loader.test_trainingSet_loader
            label = 'test_training'
        else:
            dataloader = self.data_loader.valid_loader
            label = 'valid'

        self.logger.info('\n{} set on {} \n'.format(label, self.config.exp_name))

        self.recorder.reset()
        # maxstep = self.robot.getMaxstep()
        with torch.no_grad():
            for input, target, makespan, _, tensor_map in dataloader:

                inputGPU = input.to(self.config.device)
                targetGPU = target.to(self.config.device)


                log_result = self.mutliAgent_ActionPolicy(inputGPU, targetGPU, makespan, tensor_map, self.recorder.count_validset)
                self.recorder.update(self.robot.getMaxstep(), log_result)

        self.summary_writer = self.recorder.summary(label, self.summary_writer, self.current_epoch)


        results = ('Accuracy(reachGoalnoCollision): {} \n  '
                         'DeteriorationRate(MakeSpan): {} \n  '
                         'DeteriorationRate(FlowTime): {} \n  '
                         'Rate(collisionPredictedinLoop): {} \n  '
                         'Rate(FailedReachGoalbyCollisionShielding): {} \n '.format(
                                                                  round(self.recorder.rateReachGoal, 4),
                                                                  round(self.recorder.avg_rate_deltaMP, 4),
                                                                  round(self.recorder.avg_rate_deltaFT, 4),
                                                                  round(self.recorder.rateCollisionPredictedinLoop, 4),
                                                                  round(self.recorder.rateFailedReachGoalSH, 4),
                                                                  ))
        self.logger.info(results)
        self.results_file.write('K={}, no OE\n'.format(self.config.nGraphFilterTaps))
        self.results_file.write(results)
        if self.recorder.avg_NonRogueFT:
            nonRogueOut = 'NonRogueFT: {}\n'.format(round(self.recorder.avg_NonRogueFT,4))
            self.results_file.write(nonRogueOut)
            self.logger.info(nonRogueOut)

        # if self.config.mode == 'train' and self.plot_graph:
        #     self.summary_writer.add_graph(self.model,None)
        #     self.plot_graph = False

        return self.recorder.rateReachGoal

    def mutliAgent_ActionPolicy(self, input, load_target, makespanTarget, tensor_map, ID_dataset):

        t0_setup = time.process_time()
        self.robot.setup(input, load_target, makespanTarget, tensor_map, ID_dataset)

        deltaT_setup = time.process_time()-t0_setup
        # print(" Computation time \t-[Step up]-\t\t :{} ".format(deltaT_setup))
        maxstep = self.robot.getMaxstep()

        allReachGoal = False
        noReachGoalbyCollsionShielding = False

        check_collisionFreeSol = False

        check_CollisionHappenedinLoop = False

        check_CollisionPredictedinLoop = False

        findOptimalSolution = False

        compare_makespan, compare_flowtime = self.robot.getOptimalityMetrics()
        currentStep = 0

        Case_start = time.process_time()
        Time_cases_ForwardPass = []
        for step in range(maxstep):
            currentStep = step + 1
            t0_getState = time.process_time()
            currentState = self.robot.getCurrentState()
            currentStateGPU = currentState.to(self.config.device)

            deltaT_getState = time.process_time() - t0_getState
            #print(" Computation time \t-[getState]-\t\t :{} ".format(deltaT_getState))

            t0_getGSO = time.process_time()
            gso = self.robot.getGSO(step)

            deltaT_getGSO = time.process_time() - t0_getGSO
            #print(" Computation time \t-[getGSO]-\t\t :{} ".format(deltaT_getGSO))

            gsoGPU = gso.to(self.config.device)
            pos_agents = self.robot.get_PosAgents()
            self.model.addGSO(gsoGPU)

            # model sensor failure by randomly flipping bits of 
            # map (which consists of 0s to indicate no object and 1s to indicate object)
            if self.map_noise_prob:
                bit_flip_mask = (torch.rand(currentStateGPU.shape) < self.map_noise_prob).to(self.config.device)
                # bits in mask each = 1 with probability specified
                # so xoring with this mask flips bits in map with probability specified
                currentStateGPU = torch.logical_xor(currentStateGPU, bit_flip_mask).float().to(self.config.device)
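                # logical_xor returns a bool tensor, so it is cast back to float for the network;
                # on average, a map_noise_prob fraction of the map cells end up flipped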

            # model miscalibration of sensor - shift map bits up by some number of units
            if self.map_shift_units:
                shifted_maps = currentStateGPU[:,:,:,-(currentStateGPU.shape[3] - self.map_shift_units):,:]
                # zero pad the missing rows of the field of view maps
                zero_pad_size = list(currentStateGPU.shape)
                zero_pad_size[3] = self.map_shift_units
                currentStateGPU = torch.cat((shifted_maps, torch.zeros(tuple(zero_pad_size), dtype=torch.float, device=self.config.device)), dim=3)
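                # e.g. with a field-of-view height H and map_shift_units = s, the last H - s rows
                # are kept and s zero rows are appended along dim 3, shifting the map "up" by s cells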

            
            if self.comm_dropout_param:
                # noise model:
                # create a mask that drops each message between robots i, j with probability
                # min(1, theta * distance(i, j) / communication_radius)
                distances = squareform(pdist(self.robot.get_PosAgents()[0])) 
                loss_prob = torch.from_numpy(self.comm_dropout_param/self.robot.communicationRadius * distances).to(self.config.device)
                loss_prob = torch.reshape(loss_prob, (1,1,loss_prob.shape[0], loss_prob.shape[1]))
                comm_loss_mask = (torch.rand(loss_prob.shape, device=self.config.device) > loss_prob)
            else:
                comm_loss_mask = None
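            # comm_loss_mask has shape (1, 1, N, N); presumably the model applies it to the
            # graph shift operator so that dropped entries zero out the corresponding messages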

            step_start = time.process_time()
            actionVec_predict = self.model(currentStateGPU, comm_loss_mask)
            # softmax of actionVec_predict is used to determine probabilities of each of the 5 moves
            # (so at test time, the argmax of actionVec_predict is taken as the move).
            # To simulate control errors (e.g. motors not properly responding to commands, breaking, wheels slipping, etc)
            # we add gaussian noise to actionVec_predict
            if self.move_noise_std:
                for av in actionVec_predict:
                    av += torch.normal(0.0, self.move_noise_std, list(av.shape)).to(av.device)
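                # zero-mean Gaussian noise on the action scores can flip the argmax for
                # near-tied actions, mimicking actuation errors at execution time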

            time_ForwardPass = time.process_time() - step_start



            step_move = time.process_time()
            allReachGoal, check_moveCollision, check_predictCollision = self.robot.move(actionVec_predict, currentStep, self.rogue_agent_count)
            deltaT_move = time.process_time() - step_move
            #print(" Computation time \t-[move]-\t\t :{} ".format(deltaT_move))
            #print(" Computation time \t-[loopStep]-\t\t :{}\n ".format(time.process_time() - t0_getState))
            Time_cases_ForwardPass.append([deltaT_setup, deltaT_getState, deltaT_getGSO, time_ForwardPass, deltaT_move])
            if check_moveCollision:
                check_CollisionHappenedinLoop = True


            if check_predictCollision:
                check_CollisionPredictedinLoop = True

            if allReachGoal:
                # findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality()
                # print("### Case - {} within maxstep - RealGoal: {} ~~~~~~~~~~~~~~~~~~~~~~".format(ID_dataset, allReachGoal))
                break
            elif currentStep >= (maxstep):
                # print("### Case - {} exceed maxstep - RealGoal: {} - check_moveCollision: {} - check_predictCollision: {}".format(ID_dataset, allReachGoal, check_CollisionHappenedinLoop, check_CollisionPredictedinLoop))
                break

        num_agents_reachgoal = self.robot.count_numAgents_ReachGoal()
        store_GSO, store_communication_radius = self.robot.count_GSO_communcationRadius(currentStep)

        if allReachGoal and not check_CollisionHappenedinLoop:
            check_collisionFreeSol = True
            noReachGoalbyCollsionShielding = False
            findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality(True)
            if self.config.log_anime and self.config.mode == 'test':
                self.robot.save_success_cases('success')

        if currentStep >= (maxstep):
            findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality(False)

        if currentStep >= (maxstep) and not allReachGoal and check_CollisionPredictedinLoop and not check_CollisionHappenedinLoop:
            findOptimalSolution, compare_makespan, compare_flowtime = self.robot.checkOptimality(False)
            # print("### Case - {} -Step{} exceed maxstep({})- ReachGoal: {} due to CollsionShielding \n".format(ID_dataset,currentStep,maxstep, allReachGoal))
            noReachGoalbyCollsionShielding = True
            if self.config.log_anime and self.config.mode == 'test':
                self.robot.save_success_cases('failure')
        time_record = time.process_time() - Case_start

        if self.config.mode == 'test':
            exp_status = "################## {} - End of loop ################## ".format(self.config.exp_name)
            case_status = "####### Case{} \t Computation time:{} \t Step{}/{}\t- AllReachGoal-{}\n".format(ID_dataset, time_record,
                                                                                             currentStep,
                                                                                             maxstep, allReachGoal)

            self.logger.info('{} \n {}'.format(exp_status, case_status))
            
            
        # if self.config.mode == 'test':
        #     self.robot.draw(ID_dataset)


        # elif self.config.mode == 'train' and self.current_epoch == self.config.max_epoch:
        #     # self.robot.draw(ID_dataset)
        #     pass

        # return [allReachGoal, noReachGoalbyCollsionShielding, findOptimalSolution, check_collisionFreeSol, check_CollisionPredictedinLoop, makespanPredict, makespanTarget, flowtimePredict,flowtimeTarget,num_agents_reachgoal]

        return allReachGoal, noReachGoalbyCollsionShielding, findOptimalSolution, check_collisionFreeSol, check_CollisionPredictedinLoop, compare_makespan, compare_flowtime, num_agents_reachgoal, store_GSO, store_communication_radius, time_record,Time_cases_ForwardPass, self.robot.nonRogueFlowtimePredict


    def finalize(self):
        """
        Finalizes all the operations of the two main classes of the process: the operator and the data loader
        :return:
        """
        if self.config.mode == 'train':
            print(self.model)
        print("Experiment on {} finished.".format(self.config.exp_name))
        print("Please wait while finalizing the operation.. Thank you")
        # self.save_checkpoint()
        self.summary_writer.export_scalars_to_json("{}all_scalars.json".format(self.config.summary_dir))
        self.summary_writer.close()
        self.data_loader.finalize()
        if self.config.mode == 'test':
            print("################## End of testing ################## ")
            time = "Computation time:{}\n".format(self.time_record)
            print(time)
            self.results_file.write(time)
            self.results_file.close()
Exemple #7
0
class ERFNetAgent(BaseAgent):
    """
    This class is responsible for handling the whole training and evaluation process of the architecture.
    """
    def __init__(self, config):
        super().__init__(config)
        # Create an instance from the Model
        self.logger.info("Loading encoder pretrained in imagenet...")
        if self.config.pretrained_encoder:
            pretrained_enc = torch.nn.DataParallel(
                ERFNet(self.config.imagenet_nclasses)).cuda()
            pretrained_enc.load_state_dict(
                torch.load(self.config.pretrained_model_path)['state_dict'])
            pretrained_enc = next(pretrained_enc.children()).features.encoder
        else:
            pretrained_enc = None
        # define erfNet model
        self.model = ERF(self.config, pretrained_enc)

        # Create an instance from the data loader
        #self.data_loader = VOCDataLoader(self.config)
        self.data_loader = CityscapesDataLoader(self.config)
        '''
        net_h, net_w = 448, 896
        augment = Compose([RandomHorizontallyFlip(), RandomSized((0.625, 0.75)),
                       RandomRotate(6), RandomCrop((net_h, net_w))])
        
        
        local_path = "./data/Cityscapes"

        self.data_loader = CityscapesLoader(local_path, split="test", is_transform=True, augmentations=None, gt="gtFine")
        '''
        ########################################
        self.color_transform = Colorize(self.config.num_classes)
        self.image_transform = ToPILImage()
        # Create instance from the loss
        self.loss = CrossEntropyLoss(self.config)
        # Create instance from the optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config.learning_rate,
            betas=(self.config.betas[0], self.config.betas[1]),
            eps=self.config.eps,
            weight_decay=self.config.weight_decay)
        # Define Scheduler
        lambda1 = lambda epoch: pow(
            (1 - ((epoch - 1) / self.config.max_epoch)), 0.9)
        self.scheduler = lr_scheduler.LambdaLR(self.optimizer,
                                               lr_lambda=lambda1)
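        # the LambdaLR factor decays polynomially: lr_epoch = learning_rate * (1 - (epoch - 1) / max_epoch) ** 0.9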

        # initialize my counters
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_mean_iou = 0

        # Check is cuda is available or not
        self.is_cuda = torch.cuda.is_available()
        # Construct the flag and make sure that cuda is available
        self.cuda = self.is_cuda & self.config.cuda

        if self.cuda:
            torch.cuda.manual_seed_all(self.config.seed)
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            self.logger.info("Operation will be on *****GPU-CUDA***** ")
            print_cuda_statistics()

        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.config.seed)
            self.logger.info("Operation will be on *****CPU***** ")

        self.model = self.model.to(self.device)
        self.loss = self.loss.to(self.device)
        # Model Loading from the latest checkpoint if not found start from scratch.
        self.load_checkpoint(self.config.checkpoint_file)

        # Tensorboard Writer
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='FCN8s')

        # # scheduler for the optimizer
        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
        #                                                             'min', patience=self.config.learning_rate_patience,
        #                                                             min_lr=1e-10, verbose=True)

    def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
        """
        Saving the latest checkpoint of the training
        :param filename: filename which will contain the state
        :param is_best: flag indicating whether this is the best model
        :return:
        """
        state = {
            'epoch': self.current_epoch + 1,
            'iteration': self.current_iteration,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        # Save the state
        torch.save(state, self.config.checkpoint_dir + filename)
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(self.config.checkpoint_dir + filename,
                            self.config.checkpoint_dir + 'model_best.pth.tar')

    def load_checkpoint(self, filename):
        filename = self.config.checkpoint_dir + filename
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])

            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_dir, checkpoint['epoch'],
                        checkpoint['iteration']))
        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def run(self):
        """
        The main operator: dispatches to training or testing according to the configured mode
        :return:
        """
        assert self.config.mode in ['train', 'test', 'random']
        try:
            if self.config.mode == 'test':
                self.test()
            else:
                self.train()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training function, with per-epoch model saving
        """

        for epoch in range(self.current_epoch, self.config.max_epoch):
            self.current_epoch = epoch
            self.scheduler.step(epoch)
            self.train_one_epoch()

            valid_mean_iou, valid_loss = self.validate()
            # self.scheduler.step(valid_loss)  # only valid with the ReduceLROnPlateau scheduler that is commented out in __init__

            is_best = valid_mean_iou > self.best_valid_mean_iou
            if is_best:
                self.best_valid_mean_iou = valid_mean_iou

            self.save_checkpoint(is_best=is_best)

    def train_one_epoch(self):
        """
        One epoch training function
        """
        # Initialize tqdm
        tqdm_batch = tqdm(self.data_loader.train_loader,
                          total=self.data_loader.train_iterations,
                          desc="Epoch-{}-".format(self.current_epoch))

        # Set the model to be in training mode (for batchnorm)
        self.model.train()
        # Initialize your average meters
        epoch_loss = AverageMeter()
        metrics = IOUMetric(self.config.num_classes)

        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.pin_memory().cuda(
                    non_blocking=self.config.async_loading), y.cuda(
                        non_blocking=self.config.async_loading)
            x, y = Variable(x), Variable(y)
            # model
            pred = self.model(x)
            # loss
            cur_loss = self.loss(pred, y)
            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during training...')

            # optimizer
            self.optimizer.zero_grad()
            cur_loss.backward()
            self.optimizer.step()

            epoch_loss.update(cur_loss.item())
            _, pred_max = torch.max(pred, 1)
            metrics.add_batch(pred_max.data.cpu().numpy(),
                              y.data.cpu().numpy())
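            # torch.max(pred, 1) takes the per-pixel argmax over the class dimension; the IOU
            # metric presumably accumulates the confusion between predicted and ground-truth class maps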

            self.current_iteration += 1
            # exit(0)

        epoch_acc, _, epoch_iou_class, epoch_mean_iou, _ = metrics.evaluate()
        self.summary_writer.add_scalar("epoch-training/loss", epoch_loss.val,
                                       self.current_iteration)
        self.summary_writer.add_scalar("epoch_training/mean_iou",
                                       epoch_mean_iou, self.current_iteration)
        tqdm_batch.close()

        print("Training Results at epoch-" + str(self.current_epoch) + " | " +
              "loss: " + str(epoch_loss.val) + " - acc-: " + str(epoch_acc) +
              "- mean_iou: " + str(epoch_mean_iou) + "\n iou per class: \n" +
              str(epoch_iou_class))

    def validate(self):
        """
        One epoch validation
        :return:
        """
        tqdm_batch = tqdm(self.data_loader.valid_loader,
                          total=self.data_loader.valid_iterations,
                          desc="Valiation at -{}-".format(self.current_epoch))

        # set the model in evaluation mode
        self.model.eval()

        epoch_loss = AverageMeter()
        metrics = IOUMetric(self.config.num_classes)

        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.pin_memory().cuda(
                    non_blocking=self.config.async_loading), y.cuda(
                        non_blocking=self.config.async_loading)
            x, y = Variable(x), Variable(y)
            # model
            pred = self.model(x)
            # loss
            cur_loss = self.loss(pred, y)

            if np.isnan(float(cur_loss.item())):
                #print("error")
                raise ValueError('Loss is nan during Validation.')

            _, pred_max = torch.max(pred, 1)
            metrics.add_batch(pred_max.data.cpu().numpy(),
                              y.data.cpu().numpy())

            epoch_loss.update(cur_loss.item())

        epoch_acc, _, epoch_iou_class, epoch_mean_iou, _ = metrics.evaluate()
        self.summary_writer.add_scalar("epoch_validation/loss", epoch_loss.val,
                                       self.current_iteration)
        self.summary_writer.add_scalar("epoch_validation/mean_iou",
                                       epoch_mean_iou, self.current_iteration)

        print("Validation Results at epoch-" + str(self.current_epoch) +
              " | " + "loss: " + str(epoch_loss.val) + " - acc-: " +
              str(epoch_acc) + "- mean_iou: " + str(epoch_mean_iou) +
              "\n iou per class: \n" + str(epoch_iou_class))

        tqdm_batch.close()

        return epoch_mean_iou, epoch_loss.val

    def test(self):
        '''
        test_loader = torch.utils.data.DataLoader(self.data_loader, batch_size = self.config.batch_size
                                                  ,num_workers = self.config.data_loader_workers,
                                                 pin_memory=self.config.pin_memory, shuffle = False)
        test_iterations = (len(self.data_loader) + self.config.batch_size) // self.config.batch_size
        '''
        tqdm_batch = tqdm(self.data_loader.test_loader,
                          total=self.data_loader.test_iterations,
                          desc="Test at -{}-".format(self.current_epoch))

        #tqdm_batch = tqdm(test_loader, total = test_iterations, desc = "Test at -{}-".format(self.current_epoch))
        # set the model in evaluation mode
        self.model.eval()
        '''
        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.pin_memory().cuda(non_blocking=self.config.async_loading), y.cuda(non_blocking=self.config.async_loading)
            x, y = Variable(x), Variable(y)
            # model
            pred = self.model(x)
            
        '''
        i = 0
        for x_name, x in tqdm_batch:
            if self.cuda:
                x = x.pin_memory().cuda(non_blocking=self.config.async_loading)
            x = Variable(x)
            x = x.unsqueeze(0)
            pred = self.model(x)
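            # pred[0] is the (num_classes, H, W) score map; max(0)[1] takes the per-pixel argmax
            # class, which Colorize presumably maps to a palette image and ToPILImage makes saveable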
            segmented_img = self.image_transform(
                self.color_transform(
                    pred[0].cpu().max(0)[1].data.unsqueeze(0)))

            j = str(i)
            imageio.imsave(j + ".png", segmented_img)
            i += 1
            #imageio.imsave(i+".png" , segmented_img)

        tqdm_batch.close()
        return

    def finalize(self):
        """
        Finalizes all the operations of the two main classes of the process: the operator and the data loader
        :return:
        """
        print("Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint()
        self.summary_writer.export_scalars_to_json("{}all_scalars.json".format(
            self.config.summary_dir))
        self.summary_writer.close()
        self.data_loader.finalize()