예제 #1
0
    def __init__(self, args):
        """Assemble the model, optimizer, metrics, data loaders, loss and LR scheduler.

        Args:
            args: Run options object. Attributes read directly here are
                ``inference``, ``save_best_model``, ``trainval``,
                ``loss_type``, ``lr_policy``, ``lr``, ``epochs`` and
                ``second_loss``; the object is also forwarded unchanged to
                the project factory helpers.
        """
        self.args = args
        self.best_loss = math.inf
        self.summary = TensorboardSummary(args)

        # Build the network; in inference mode it is routed through the
        # summary object's loader before being stored.
        network = get_model(args)
        if args.inference:
            network = self.summary.load_network(network)
        self.model = network

        # Optional independent copy of the model, taken right after construction.
        if args.save_best_model:
            self.best_model = copy.deepcopy(network)

        self.optimizer = get_optimizer(network, args)

        ssim_metric = SSIM()
        ms_ssim_metric = MS_SSIM()
        self.ssim = ssim_metric
        self.ms_ssim = ms_ssim_metric

        # NOTE(review): the TEST-split loader is stored as ``val_loader`` in
        # the trainval branch but as ``test_loader`` otherwise -- confirm
        # that callers expect this asymmetry.
        if args.trainval:
            train_split_loader = make_data_loader(args, TRAINVAL)
            held_out_loader = make_data_loader(args, TEST)
            self.train_loader = train_split_loader
            self.val_loader = held_out_loader
        else:
            train_split_loader = make_data_loader(args, TRAIN)
            held_out_loader = make_data_loader(args, TEST)
            self.train_loader = train_split_loader
            self.test_loader = held_out_loader

        self.criterion = get_loss_function(args.loss_type)

        # The scheduler receives the length of the training loader as its
        # last argument.
        self.scheduler = LR_Scheduler(
            args.lr_policy,
            args.lr,
            args.epochs,
            len(self.train_loader),
        )

        # A second criterion is only created when explicitly requested.
        if args.second_loss:
            self.second_criterion = get_loss_function(MS_SSIM_LOSS)
예제 #2
0
    def __init__(self, args):
        """Set up counters, model, optimizer, data loaders, losses and evaluator.

        Args:
            args: Run options object. Attributes read directly here are
                ``mode``, ``segmentation``, ``reconstruct``, ``trainval``,
                ``use_class_weights``, ``weighting_mode``, ``loss_type``,
                ``reconstruct_loss_type``, ``lr_policy``, ``lr`` and
                ``epochs``; the object is also forwarded unchanged to the
                project factory helpers.
        """
        self.args = args

        # Step counters and best-score tracking start from fixed values.
        self.tr_global_step = 1
        self.val_global_step = 1
        self.best_mIoU = 0
        self.num_classes = CityScapes.num_classes

        # Mirror a few options as instance attributes.
        self.mode = args.mode
        self.segmentation = args.segmentation
        self.reconstruct = args.reconstruct

        network = get_model(args)
        self.model = network
        # Independent copy of the model, taken right after construction.
        self.best_model = copy.deepcopy(network)
        self.optimizer = get_optimizer(network, args)
        self.summary = TensorboardSummary(args)

        # Pick the pair of dataset splits, then build both loaders.
        if args.trainval:
            train_split, eval_split = 'trainval', 'test'
        else:
            train_split, eval_split = 'train', 'val'
        train_split_loader = make_data_loader(args, train_split)
        eval_split_loader = make_data_loader(args, eval_split)
        self.train_loader = train_split_loader
        self.val_loader = eval_split_loader

        # Class weights are derived from the training loader only on request;
        # otherwise the loss factory receives None.
        if args.use_class_weights:
            self.class_weights = get_class_weights(
                self.train_loader, self.num_classes, args.weighting_mode)
        else:
            self.class_weights = None
        self.criterion = get_loss_function(args.loss_type, self.class_weights)

        # The reconstruction criterion exists only when reconstruction is on.
        if self.reconstruct:
            self.reconstruction_criterion = get_reconstruction_loss_function(
                args.reconstruct_loss_type)

        # The scheduler receives the length of the training loader as its
        # last argument.
        self.scheduler = LR_Scheduler(
            args.lr_policy,
            args.lr,
            args.epochs,
            len(self.train_loader),
        )
        self.evaluator = Evaluator(self.num_classes)
예제 #3
0
    def __init__(self, args):
        """
        Build the summary writer, model and data loaders, plus the loss,
        global step and optimizer when not running inference.

        Args:
            args: object carrying the command line arguments. Attributes read
                directly here are ``inference``, ``trainval``,
                ``save_best_model`` and ``loss_type``.
        """
        self.args = args
        self.best_loss = math.inf
        self.summary = TensorboardSummary(args)
        self.model = get_model(args)

        inference_mode = args.inference

        if inference_mode:
            # Inference: pass the model through the summary object's loader
            # and build the inference and test loaders (no training loader).
            self.model = self.summary.load_network(self.model)
            self.inference_loader = make_data_loader(args, INFERENCE)
            self.test_loader = make_data_loader(args, TEST)
        else:
            # Training: only the training split depends on ``trainval``; the
            # second loader always uses the TEST split.
            train_split = TRAINVAL if args.trainval else TRAIN
            train_split_loader = make_data_loader(args, train_split)
            test_split_loader = make_data_loader(args, TEST)
            self.train_loader = train_split_loader
            self.test_loader = test_split_loader

        # Optional independent copy of the model as it stands at this point.
        if args.save_best_model:
            self.best_model = copy.deepcopy(self.model)

        # Nothing below is created in inference mode.
        if inference_mode:
            return

        self.criterion = get_loss(args.loss_type)
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        self.optimizer = get_optimizer(
            args,
            self.global_step,
            self.train_loader.length,
        )
예제 #4
0
 def load_network(self, path=''):
     """Rebuild the model and restore its weights into ``self.best_model``.

     The checkpoint path used to be hard-coded to the empty string, so the
     load could never succeed, and ``self.best_model`` had already been
     replaced by a freshly built, unloaded model by the time it failed.
     The path is now a parameter and ``self.best_model`` is only replaced
     after the weights have loaded successfully.

     Args:
         path: Location of the saved state dict, passed to ``torch.load``.
             Defaults to ``''`` so existing zero-argument calls still
             raise ``FileNotFoundError`` as before.

     Raises:
         FileNotFoundError: If ``path`` is empty, or if ``torch.load``
             cannot find the file.
     """
     if not path:
         # Same exception type torch.load('') produced, with a clearer message.
         raise FileNotFoundError(
             'load_network requires a non-empty checkpoint path')

     # Read the checkpoint first so a bad path fails before any model is
     # built or any instance state is touched.
     state_dict = torch.load(path)

     restored_model = get_model(self.args)
     restored_model.load_state_dict(state_dict)

     # Commit only after a successful load.
     self.best_model = restored_model