Example #1
0
    def __init__(self, args, model, train_loader, test_loader, log=None):
        """Store the data loaders and logger, and build the wrapped loss.

        Args:
            args: run configuration; ``src_device`` and ``device_ids`` are read here.
            model: model handed to ``TMGLowLoss``.
            train_loader: stored as ``self.trainingLoader``.
            test_loader: stored as ``self.testingLoader``.
            log: optional logger; ``Log(args)`` is created when omitted.
        """
        super().__init__()
        self.args = args
        self.trainingLoader, self.testingLoader = train_loader, test_loader
        self.log = Log(self.args) if log is None else log

        # The loss module is placed on the source device, then wrapped in
        # DataParallelCriterion for the devices listed in args.device_ids.
        criterion = TMGLowLoss(args, model).to(args.src_device)
        self.parallel_loss = DataParallelCriterion(criterion, args.device_ids)
Example #2
0
def create_single_model(args):
    """Instantiate ``args.model`` from the hyper-parameters carried by *args*.

    Side effects: ``args.loss`` is set to the model's ``loss`` attribute. When
    ``args.multi_gpu`` is true, the model is wrapped in ``DataParallelModel``
    and ``args.loss`` in ``DataParallelCriterion`` (both over ``args.device_ids``).

    Returns:
        The (possibly wrapped) model, after ``.to(args.device)`` was called on it.
    """
    print('Creating model with\n \
           Input size: {}\n \
           Output size: {}\n \
           Activation {}\n \
           Num layers: {}\n \
           Hidden units per layer: {}\n \
           Using bias: {}\n \
           Using batchnorm {}\n \
           With batchsize {}'.format( \
           args.input_size, args.output_size, args.actv,
           args.num_layers, args.hidden, args.bias, args.bn, args.batch))

    # Map the argparse option names onto the model constructor's keywords.
    ctor_kwargs = {
        'input_size': args.input_size,
        'output_size': args.output_size,
        'actv_type': args.actv,
        'num_layers': args.num_layers,
        'hidden_size': args.hidden,
        'bias': args.bias,
        'use_bn': args.bn,
    }
    model = args.model(**ctor_kwargs)
    args.loss = model.loss

    if args.multi_gpu:
        print('Using data parallelism with {} GPUs'.format(args.num_gpu))
        gpu_ids = args.device_ids
        model = DataParallelModel(model, device_ids=gpu_ids)
        args.loss = DataParallelCriterion(args.loss, device_ids=gpu_ids)

    print('Sending model to device {}'.format(args.device))
    model.to(args.device)
    return model
Example #3
0
	def __init__(self, dataloader, hierarchical_transformer, config, i):
		"""Wire up the encoders, the (optionally GPU-wrapped) model, loss and optimizer.

		Args:
			dataloader: must expose ``tweet_field.vocab`` for the word encoder.
			hierarchical_transformer: model to train; moved to CUDA and wrapped
				in ``DataParallelModel`` when ``config.gpu`` is truthy.
			config: settings read here (``gpu``, ``gpu_idx``, ``max_length``,
				``size``, ``d_model``, ``beta_1``, ``beta_2``).
			i: stored as ``self.iter``.
		"""
		super(Trainer, self).__init__()

		self.iter, self.config = i, config
		self.cpu = torch.device("cpu")
		# True when more than one GPU index is configured (not consulted in this method).
		self.multi_gpu = len(self.config.gpu_idx) > 1

		self.dataloader = dataloader
		vocab = self.dataloader.tweet_field.vocab
		self.word_encoder = WordEncoder.WordEncoder(config, vocab)
		self.word_pos_encoder = PositionEncoder.PositionEncoder(config, self.config.max_length)
		self.time_delay_encoder = PositionEncoder.PositionEncoder(config, self.config.size)

		# <----------- Check for GPU setting ----------->
		if not self.config.gpu:
			self.hierarchical_transformer = hierarchical_transformer
			self.criterion = nn.NLLLoss()
		else:
			self.hierarchical_transformer = DataParallelModel(hierarchical_transformer.cuda())
			self.criterion = DataParallelCriterion(nn.NLLLoss())

		# Adam's base learning rate is d_model ** -0.5; the Adam instance is then
		# handed to the project's Optimizer wrapper.
		base_lr = np.power(self.config.d_model, - 0.5)
		adam_betas = (self.config.beta_1, self.config.beta_2)
		self.adam_optimizer = optim.Adam(self.hierarchical_transformer.parameters(), base_lr, betas = adam_betas)
		self.optimizer = Optimizer.Optimizer(self.config, self.adam_optimizer)
Example #4
0
 def init_fn(self, shared_model=None, **kwargs):
     """Create the model, optimizer, LR scheduler, criterion and loss meter.

     Args:
         shared_model: if given, used as ``self.model`` as-is; otherwise a new
             model from ``self.init_model()`` is moved to CUDA and wrapped in
             ``DataParallelModel`` over ``self.gpus``.
         **kwargs: accepted for interface compatibility; unused.
     """
     # Create auxiliary models
     self.init_auxiliary()
     if shared_model is None:
         fresh_model = self.init_model()
         self.model = DataParallelModel(fresh_model.cuda(), device_ids=self.gpus)
     else:
         self.model = shared_model

     # Setup a joint optimizer for the 2 models
     self.optimizer = self.init_optimizer(self.options.optim.name)
     self.lr_scheduler = self.init_lr(self.options.optim.lr_scheduler)

     # The criterion is always moved to CUDA and wrapped for self.gpus.
     base_criterion = self.init_loss_functions()
     self.criterion = DataParallelCriterion(base_criterion.cuda(), device_ids=self.gpus)

     self.losses = AverageMeter()
     self.dataset_size = None
Example #5
0
    def define_criterions(self):
        """Create the loss modules used for training.

        - ``criterionGAN``: ``nn.MSELoss`` (least-squares GAN objective).
        - ``criterionCycle`` and ``criterionIdt``: ``nn.L1Loss`` for the cycle and
          identity terms. The identity loss matters for the monet2photo task,
          where it keeps the color context (see section 5.2 of the CycleGAN paper,
          "photo generation from paintings").

        When ``opt.parallel`` is set and more than one CUDA device is visible,
        each criterion is wrapped in ``DataParallelCriterion``.
        """
        self.criterionGAN = nn.MSELoss()  # LSGAN losses
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

        wants_parallel = opt.parallel and torch.cuda.device_count() > 1
        if wants_parallel:
            for attr_name in ("criterionGAN", "criterionCycle", "criterionIdt"):
                setattr(self, attr_name, DataParallelCriterion(getattr(self, attr_name)))
Example #6
0
def main(args):
    """Train the segmentation model described by *args*, keeping the best-mIoU weights.

    Validation runs on every 10th epoch and on every epoch after 80% of
    ``args.epochs``. Whenever validation mIoU improves, the bare ``seg_model``
    weights (not the DataParallel wrapper) are saved to
    ``<snapshot_dir>/<method>_miou.pth``. At the end, that file is renamed to
    embed the best mIoU value.

    Args:
        args: parsed command-line options (``snapshot_dir``, ``log_dir``, ``method``,
            ``seed``, ``num_classes``, ``restore_from``, ``init``, data paths,
            ``crop_size``, ``batch_size``, ``learning_rate``, ``epochs``, ...).
    """
    # initialization
    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(args.snapshot_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.method))

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True

    # conduct seg network
    seg_model = get_model(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)
    new_params = seg_model.state_dict().copy()

    if args.init:
        # Initialise the encoder from the checkpoint, skipping the 'fc' head and
        # prefixing every key with 'encoder.'.
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not i_parts[0] == 'fc':
                new_params['encoder.' + '.'.join(i_parts[:])] = saved_state_dict[i]
        seg_model.load_state_dict(new_params)
        print('loading params w/o fc')
    else:
        seg_model.load_state_dict(saved_state_dict)
        print('loading params all')

    model = DataParallelModel(seg_model)
    model.float()
    model.cuda()

    # define dataloader
    train_loader = data.DataLoader(DataGenerator(root=args.root, list_path=args.lst,
                                                    crop_size=args.crop_size, training=True),
                                   batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = data.DataLoader(DataGenerator(root=args.val_root, list_path=args.val_lst,
                                                  crop_size=args.crop_size, training=False),
                                 batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # define criterion & optimizer
    criterion = ABRLovaszLoss(ignore_index=args.ignore_label, only_present=True, cls_p= args.num_classes, cls_h= args.hbody_cls, cls_f= args.fbody_cls)
    criterion = DataParallelCriterion(criterion).cuda()

    optimizer = optim.SGD(
        [{'params': filter(lambda p: p.requires_grad, seg_model.parameters()), 'lr': args.learning_rate}],
        lr=args.learning_rate, momentum=0.9, weight_decay=5e-4)

    # key points
    best_val_mIoU = 0
    best_val_pixAcc = 0
    # Path of the best-mIoU checkpoint; stays None when no checkpoint is ever
    # written (zero epochs, or validation mIoU never exceeds 0).
    model_dir = None
    start = time.time()

    for epoch in range(0, args.epochs):
        print('\n{} | {}'.format(epoch, args.epochs - 1))
        # training
        _ = train(model, train_loader, epoch, criterion, optimizer, writer)

        # validation
        if epoch % 10 == 0 or epoch > args.epochs * 0.8:
            val_pixacc, val_miou = validation(model, val_loader, epoch, writer)
            # save model
            if val_pixacc > best_val_pixAcc:
                best_val_pixAcc = val_pixacc
            if val_miou > best_val_mIoU:
                best_val_mIoU = val_miou
                model_dir = os.path.join(args.snapshot_dir, args.method + '_miou.pth')
                torch.save(seg_model.state_dict(), model_dir)
                print('Model saved to %s' % model_dir)

    if model_dir is not None:
        os.rename(model_dir, os.path.join(args.snapshot_dir, args.method + '_miou'+str(best_val_mIoU)+'.pth'))
    else:
        print('No checkpoint was saved; nothing to rename.')
    print('Complete using', time.time() - start, 'seconds')
    print('Best pixAcc: {} | Best mIoU: {}'.format(best_val_pixAcc, best_val_mIoU))
Example #7
0
def main():
    """Train a segmentation network configured by the module-level ``args``.

    Steps:
    - Seed the RNGs and build the model named ``<network>_<method>``.
    - Restore weights from ``args.restore_from``, remapping checkpoint keys
      according to ``args.network`` / ``args.method``.
    - Choose the criterion from ``args.method`` and the ``args.ohem*`` options.
    - Run SGD until ``args.num_steps`` iterations.
    - Save a snapshot every ``args.save_pred_every`` iterations and at the
      final step.
    """
    # Wall-clock reference for the elapsed-time report at the end (this was
    # previously never defined, so the final print raised NameError).
    start = timeit.default_timer()

    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    writer = SummaryWriter(args.snapshot_dir)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    cudnn.enabled = True

    deeplab = get_segmentation_model("_".join([args.network, args.method]), num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)
    new_params = deeplab.state_dict().copy()

    # Copy checkpoint weights into new_params, skipping classifier heads; which
    # keys are skipped/remapped depends on the backbone family.
    if 'wide' in args.network:
        saved_state_dict = saved_state_dict['state_dict']
        if 'vistas' in args.method:
            saved_state_dict = saved_state_dict['body']
            for i in saved_state_dict:
                new_params[i] = saved_state_dict[i]
        else:
            for i in saved_state_dict:
                i_parts = i.split('.')
                if 'classifier' not in i_parts:
                    # Drop the first key component (presumably a wrapper prefix).
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
    elif 'mobilenet' in args.network:
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not (i_parts[0] == 'features' and i_parts[1] == '18') and not i_parts[0] == 'classifier':
                new_params[i] = saved_state_dict[i]
    else:
        for i in saved_state_dict:
            i_parts = i.split('.')
            if i_parts[0] not in ('fc', 'last_linear', 'classifier'):
                new_params[i] = saved_state_dict[i]

    # When resuming (start_iters > 0) the checkpoint is loaded verbatim.
    if args.start_iters > 0:
        deeplab.load_state_dict(saved_state_dict)
    else:
        deeplab.load_state_dict(new_params)

    model = DataParallelModel(deeplab)
    model.train()
    model.float()
    model.cuda()

    criterion = CriterionCrossEntropy()
    if "dsn" in args.method:
        if args.ohem:
            if args.ohem_single:
                print('use ohem only for the second prediction map.')
                criterion = CriterionOhemDSN_single(thres=args.ohem_thres, min_kept=args.ohem_keep, dsn_weight=float(args.dsn_weight))
            else:
                criterion = CriterionOhemDSN(thres=args.ohem_thres, min_kept=args.ohem_keep, dsn_weight=float(args.dsn_weight), use_weight=True)
        else:
            criterion = CriterionDSN(dsn_weight=float(args.dsn_weight), use_weight=True)

    criterion = DataParallelCriterion(criterion)
    criterion.cuda()
    cudnn.benchmark = True

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(get_segmentation_dataset(args.dataset, root=args.data_dir, list_path=args.data_list,
                    max_iters=args.num_steps*args.batch_size, crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, network=args.network),
                    batch_size=args.batch_size, shuffle=True, num_workers=1, pin_memory=True)

    optimizer = optim.SGD([{'params': filter(lambda p: p.requires_grad, deeplab.parameters()), 'lr': args.learning_rate }],
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)

    for i_iter, batch in enumerate(trainloader):
        sys.stdout.flush()
        i_iter += args.start_iters
        images, labels, _, _ = batch
        images = Variable(images.cuda())
        labels = Variable(labels.long().cuda())
        optimizer.zero_grad()
        # With fix_lr the optimizer keeps the learning rate it was built with.
        # Previously adjust_learning_rate(optimizer, ...) was still called and
        # only the logged value was overridden.
        if args.fix_lr:
            lr = args.learning_rate
        else:
            lr = adjust_learning_rate(optimizer, i_iter)
        print('learning_rate: {}'.format(lr))

        if 'gt' in args.method:
            preds = model(images, labels)
        else:
            preds = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        # Convert once; reused for TensorBoard and the console log.
        loss_value = loss.data.cpu().numpy()
        if i_iter % 100 == 0:
            writer.add_scalar('learning_rate', lr, i_iter)
            writer.add_scalar('loss', loss_value, i_iter)
        print('iter = {} of {} completed, loss = {}'.format(i_iter, args.num_steps, loss_value))

        if i_iter >= args.num_steps-1:
            print('save model ...')
            torch.save(deeplab.state_dict(), osp.join(args.snapshot_dir, 'CS_scenes_'+str(args.num_steps)+'.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(deeplab.state_dict(), osp.join(args.snapshot_dir, 'CS_scenes_'+str(i_iter)+'.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main(args):
    # initialization
    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.method))

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True

    # conduct seg network
    seg_model = get_model(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)
    new_params = seg_model.state_dict().copy()

    # if args.init:
    #     for i in saved_state_dict:
    #         i_parts = i.split('.')
    #         if not i_parts[0] == 'fc':
    #             new_params['encoder.' + '.'.join(i_parts[:])] = saved_state_dict[i]
    #     seg_model.load_state_dict(new_params)
    #     print('loading params w/o fc')
    # else:
    #     seg_model.load_state_dict(saved_state_dict)
    #     print('loading params all')

    model = DataParallelModel(seg_model)
    model.float()
    model.cuda()

    # define dataloader
    train_loader = data.DataLoader(TrainGenerator(root=args.root,
                                                  list_path=args.lst,
                                                  crop_size=args.crop_size,
                                                  max_scale=2.0),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4,
                                   pin_memory=True)

    # define criterion & optimizer
    criterion = ReportLovaszLoss(ignore_index=args.ignore_label,
                                 only_present=True)
    criterion = DataParallelCriterion(criterion).cuda()

    optimizer = optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad,
                             seg_model.parameters()),
            'lr': args.learning_rate
        }],
        lr=args.learning_rate,
        momentum=0.9,
        weight_decay=5e-4)

    start = time.time()

    for epoch in range(0, args.epochs):
        print('\n{} | {}'.format(epoch, args.epochs - 1))
        # training
        _ = train(model, train_loader, epoch, criterion, optimizer, writer)

        if epoch == args.epochs - 1:
            model_dir = os.path.join(args.snapshot_dir,
                                     args.method + '_final.pth')
            torch.save(seg_model.state_dict(), model_dir)
            print('Model saved to %s' % model_dir)

    print('Complete using', time.time() - start, 'seconds')
    saved_state_dict = torch.load(restore_from)
    new_params = deeplab.state_dict().copy()
    for i in saved_state_dict:
        i_parts = i.split('.')
        if not i_parts[0] == 'fc' and not i_parts[
                0] == 'last_linear' and not i_parts[0] == 'classifier':
            new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
    deeplab.load_state_dict(new_params)

    model = DataParallelModel(deeplab)
    model.train()
    model.float()
    model.cuda()

    criterion = CriterionCrossEntropy()
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    train_dataset = DatasetCityscapesAugmentation(root=data_dir,
                                                  list_path=data_list,
                                                  max_iters=num_steps *
                                                  batch_size,
                                                  crop_size=crop_size)
    train_loader = data.DataLoader(dataset=train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=1,
                                   pin_memory=True)

    optimizer = optim.SGD(
        [{
Example #10
0
class TrainFlow(object):
    '''
    Trains recursive encoder attention-driven decoder
    Args:
        args (argparse): object with programs arguments
        model (torch.nn): Glow-TM model
        train_loader (torch.dataloader): dataloader with training cases
        test_loader (torch.dataloader): dataloader with testing cases
        log (Log): class for logging console outputs
    '''
    def __init__(self, args, model, train_loader, test_loader, log=None):
        super().__init__()
        self.args = args
        self.trainingLoader = train_loader
        self.testingLoader = test_loader

        self.log = Log(self.args) if log is None else log

        # Loss module on the source device, wrapped for args.device_ids.
        loss = TMGLowLoss(args, model).to(args.src_device)
        self.parallel_loss = DataParallelCriterion(loss, args.device_ids)

    def trainParallel(self, model, optimizer, tback=10, epoch=0, **kwargs):
        '''
        Trains the model for a single epoch
        Args:
            model (torch.nn.Module): PyTorch model to train (a wrapper exposing
                ``module``, ``sample``, ``scatterModel`` and ``gatherLSTMStates``)
            optimizer (torch.optim): PyTorch optimizer to update the models parameters
            tback (int): number of time-steps to back propagate through in time.
                Each training series is split into ``tmax // tback`` blocks, and
                one optimizer step is taken per block.
            epoch (int): current epoch (used only for logging)
            **kwargs: accepted for interface compatibility; unused
        Returns:
            total_loss: sum of the detached per-block losses (0 if no block ran)
        '''
        model.train()
        # Total training loss
        total_loss = 0
        beta = self.args.beta

        print("Beta:", beta)
        optimizer.zero_grad()
        for mbIdx, (input0, target0,
                    lstm_seeds) in enumerate(self.trainingLoader):

            # Initial recurrent states. aKey is kept as the anchor that the
            # states are blended back towards after every optimizer step.
            aKey = model.module.initLSTMStates(
                lstm_seeds,
                [target0.size(-2), target0.size(-1)])
            a0 = copy.deepcopy(aKey)

            # Number of time-steps in this mini-batch and resulting block count
            tmax = target0.size(1)
            n_blocks = tmax // tback

            input_next = input0[:, :tback].to(self.args.device)
            target_next = target0[:, :tback].to(self.args.device)

            # Mean / RMS of the targets over the time dimension, passed to the loss
            target0_mean = torch.mean(target0, dim=1).to(self.args.device)
            target0_rms = torch.sqrt(
                torch.mean((target0.to(self.args.device) -
                            target0_mean.unsqueeze(1))**2,
                           dim=1)).to(self.args.device)

            # Splits time-series into smaller blocks to calculate back-prop through time
            for blk in range(n_blocks):

                inputs = input_next
                ytarget = target_next

                # Asynch load the next time-series block
                if blk + 1 < n_blocks:
                    lo, hi = (blk + 1) * tback, (blk + 2) * tback
                    input_next = input0[:, lo:hi].cuda(self.args.device,
                                                       non_blocking=True)
                    target_next = target0[:, lo:hi].cuda(self.args.device,
                                                         non_blocking=True)

                gpu_loss = [0] * self.args.n_gpu
                modelPredictions = TMGLowPredictionItem()
                model.scatterModel()

                for tstep in range(tback):

                    # Model forward
                    outputs = model.sample(inputs[:, tstep], a0)

                    if isinstance(outputs, list):
                        yPred = [output[0] for output in outputs]
                        logp = [output[1] for output in outputs]
                        a0 = [output[2] for output in outputs]
                    else:
                        yPred, logp, a0 = outputs

                    modelPredictions.add(yPred, logp, ytarget[:, tstep])
                    # Recompile recurrent states onto the source device
                    if self.args.n_gpu > 1:
                        a0 = model.gatherLSTMStates(a0)
                    else:
                        a0 = outputs[2]

                # Compute the reverse KL divergence loss
                outputs = modelPredictions.getOutputs()
                targets = modelPredictions.getTargets()
                loss0 = self.parallel_loss(outputs, targets, target0_mean,
                                           target0_rms)
                if self.args.n_gpu > 1:
                    gpu_loss = [
                        gpu_loss[k] + loss0[k] for k in range(len(loss0))
                    ]
                else:
                    gpu_loss = [gpu_loss[0] + loss0]

                modelPredictions.clear()
                loss = self.parallel_loss.gather(
                    gpu_loss, output_device=self.args.src_device).mean()
                # Backwards!
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               self.args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                # Average the (detached) LSTM states with the initial state to
                # prevent convergence; detaching also truncates back-prop here.
                for j in range(len(a0)):
                    a_out, c_out = a0[j]
                    a_key, c_key = aKey[j]
                    a0[j] = (0.5 * a_out.detach() + 0.5 * a_key,
                             0.5 * c_out.detach() + 0.5 * c_key)

                total_loss = total_loss + loss.detach()
                # Sync CUDA before the next block (intended to make sure the
                # asynchronously copied next block has arrived).
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

            # Mini-batch progress log
            if (mbIdx + 1) % 5 == 0:
                self.log.log(
                    'Train Epoch: {}; Mini-batch: {}/{} ({:.0f}%); \t Current Loss: {:.6f}'
                    .format(epoch, mbIdx, len(self.trainingLoader),
                            100. * mbIdx / len(self.trainingLoader),
                            total_loss))

        return total_loss

    def test(self, model, samples=1, epoch=0, plot=True):
        '''
        Tests the model
        Args:
            model (torch.nn.Module): PyTorch model to test (wrapper exposing
                ``module``, ``sample``, ``gather`` and ``scatterModel``)
            samples (int): Number of prediction to sample from the model
            epoch (int): current epoch
            plot (boolean): If to plot two of the predictions or not
        Returns:
            mse: mean-squared-error between the predictive mean and targets,
                normalised by ``args.ntest * tmax`` and the spatial size of the
                last batch's predictions
        Raises:
            ValueError: if ``self.testingLoader`` yields no batches
        '''
        model.eval()
        # Total test loss
        total_loss = 0
        # Max number of time steps predicted per test case
        tmax = 40
        # Prediction buffer of the most recent batch; None until a batch is seen
        yPred = None
        # De-normalisation applied to targets and predictions:
        # y = out_std * y_normalised + out_mu
        out_std = model.module.out_std.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        out_mu = model.module.out_mu.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        for mbIdx, (input0, target0, _) in enumerate(self.testingLoader):

            inputs = input0.to(self.args.device)
            ytarget = out_std * target0.to(self.args.device) + out_mu

            dims = [samples] + list(ytarget.size())
            yPred = torch.zeros(dims).type(ytarget.type())

            # Loop through samples
            for i in range(samples):

                aKey = model.module.initLSTMStates(
                    torch.LongTensor(inputs.size(0)).random_(0, int(1e8)),
                    [ytarget.size(-2), ytarget.size(-1)])
                a0 = copy.deepcopy(aKey)

                model.scatterModel()

                # Loop of time-steps
                for tstep in range(0, tmax + 1):

                    # Model forward
                    outputs = model.sample(inputs[:, tstep], a0)
                    yPred0, logp, a0 = model.gather(outputs,
                                                    self.args.src_device)

                    yPredHat = out_std * yPred0 + out_mu
                    yPred[i, :, tstep] = yPredHat.detach()

                    # Average current LSTM states with initial state
                    if tstep % 10 == 0:
                        for j in range(len(a0)):
                            a_out, c_out = a0[j]
                            a_key, c_key = aKey[j]
                            a0[j] = (0.5 * a_out.detach() + 0.5 * a_key,
                                     0.5 * c_out.detach() + 0.5 * c_key)

            if plot and mbIdx == 0:
                self.log.log('Plotting predictions.')
                for bidx in (0, 1):
                    plotVelocityPred(self.args,
                                     inputs,
                                     yPred,
                                     ytarget,
                                     bidx=bidx,
                                     stride=4,
                                     epoch=epoch)

            # Summation of the squared error between the mean of the samples and target
            total_loss = total_loss + (torch.pow(
                torch.mean(yPred[:, :, 1:tmax + 1], dim=0) -
                ytarget[:, 1:tmax + 1], 2)).sum().detach()

        if yPred is None:
            raise ValueError(
                'testingLoader produced no batches; cannot compute the test MSE')

        # Return the mse
        return total_loss / (self.args.ntest * tmax * yPred.size(-2) *
                             yPred.size(-1))
Example #11
0
class CycleGAN:
    """CycleGAN: two generators and two discriminators for unpaired translation.

    Cycle-consistency losses force each generator to keep the content of its
    input, so that A -> B -> A (and B -> A -> B) reconstructs the original
    image. This lets the generators learn style transfer without a paired
    dataset.

    Relies on a module-level ``opt`` object for ``image_pool_size``,
    ``parallel``, ``batch_size`` and ``begin_decay``.
    """
    def __init__(self,
                 lr=2e-4,
                 betas=(.5, .999),
                 n_epochs=200,
                 ngf=64,
                 ndf=64,
                 lambdaA=10.,
                 lambdaB=10.,
                 lambdaIdt=.5):
        """Store hyper-parameters and build nets, losses, optimizers, schedulers.

        :param lr: initial learning rate of both Adam optimizers.
        :param betas: Adam betas of both optimizers.
        :param n_epochs: epoch at which the LR multiplier reaches 0.
        :param ngf: passed to ``define_G`` (presumably the base filter count).
        :param ndf: passed to ``define_D`` (presumably the base filter count).
        :param lambdaA: weight of the A cycle loss; also scales the idt_A loss.
        :param lambdaB: weight of the B cycle loss; also scales the idt_B loss.
        :param lambdaIdt: extra multiplier applied to both identity losses.
        """
        self.learning_rate = lr
        self.betas = betas
        self.n_epochs = n_epochs
        self.ngf = ngf
        self.ndf = ndf
        self.lambdaA = lambdaA  # coefficients of cycle losses
        self.lambdaB = lambdaB
        self.lambdaIdt = lambdaIdt  # coefficient of Identity losses

        # Scalar targets for the LSGAN losses; expanded to the discriminator
        # output size when the losses are computed.
        self.true_label = torch.tensor(1.)
        self.false_label = torch.tensor(0.)

        # todo: need to add data parallel code
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # History buffers of generated images fed to the discriminators.
        self.fakeA_pool = ImagePool(opt.image_pool_size)
        self.fakeB_pool = ImagePool(opt.image_pool_size)

        # Order matters: optimizers must be built after the nets are moved,
        # and schedulers after the optimizers exist.
        self.define_nets()
        self.define_criterions()
        self.move_to()
        self.define_optimizers()
        self.define_schedulers()

    def define_nets(self):
        """Define generators and discriminators for both directions.

        Also collects the generator and discriminator parameters into
        ``G_params`` / ``D_params`` for the two optimizers.
        """
        self.netG_A = define_G(self.ngf)
        self.netG_B = define_G(self.ngf)
        self.netD_A = define_D(self.ndf)
        self.netD_B = define_D(self.ndf)

        self.G_params = list(self.netG_A.parameters()) + list(
            self.netG_B.parameters())
        self.D_params = list(self.netD_A.parameters()) + list(
            self.netD_B.parameters())

    def define_criterions(self):
        """Define the loss functions.

        Uses the LSGAN (MSE) loss for the adversarial terms and L1 for the
        cycle and identity terms. The identity loss is used in the
        monet2photo task to preserve colour composition (paper section 5.2,
        "photo generation from paintings"). When ``opt.parallel`` is set and
        more than one GPU is visible, each criterion is wrapped in
        ``DataParallelCriterion``.
        """
        self.criterionGAN = nn.MSELoss()  # LSGAN losses
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

        if opt.parallel and torch.cuda.device_count() > 1:
            self.criterionGAN = DataParallelCriterion(self.criterionGAN)
            self.criterionCycle = DataParallelCriterion(self.criterionCycle)
            self.criterionIdt = DataParallelCriterion(self.criterionIdt)

    def define_optimizers(self):
        """Create one Adam optimizer for both generators and one for both discriminators."""
        self.optimizerG = optim.Adam(self.G_params,
                                     lr=self.learning_rate,
                                     betas=self.betas)
        self.optimizerD = optim.Adam(self.D_params,
                                     lr=self.learning_rate,
                                     betas=self.betas)

    def define_schedulers(self):
        """Create LambdaLR schedulers for both optimizers.

        The LR multiplier is ``(n_epochs - epoch) / (n_epochs - opt.begin_decay - 1)``,
        clamped to at most 1. It stays at 1 up to epoch ``opt.begin_decay + 1``
        and then decays linearly to 0 at epoch ``n_epochs``. This assumes
        ``opt.begin_decay < n_epochs - 1``. If the two are equal minus one, the
        denominator is 0, and past ``n_epochs`` the multiplier turns negative.
        """
        def lambda_rule(epoch):
            return min(1., (epoch - self.n_epochs) /
                       (opt.begin_decay - self.n_epochs + 1))

        self.schedulers = [
            optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
            for optimizer in [self.optimizerG, self.optimizerD]
        ]

    def move_to(self):
        """Move nets, criteria and label tensors to ``self.device``."""
        self.netG_A = self.netG_A.to(self.device, non_blocking=True)
        self.netG_B = self.netG_B.to(self.device, non_blocking=True)
        self.netD_A = self.netD_A.to(self.device, non_blocking=True)
        self.netD_B = self.netD_B.to(self.device, non_blocking=True)
        self.criterionGAN = self.criterionGAN.to(self.device,
                                                 non_blocking=True)
        self.criterionCycle = self.criterionCycle.to(self.device,
                                                     non_blocking=True)
        self.criterionIdt = self.criterionIdt.to(self.device,
                                                 non_blocking=True)
        self.true_label = self.true_label.to(self.device, non_blocking=True)
        self.false_label = self.false_label.to(self.device, non_blocking=True)

    def backward_G(self):
        """Compute the generator losses and back-propagate ``loss_G``.

        Must be called after ``forward``. The backward pass also deposits
        gradients in the discriminator parameters; ``backward`` clears them
        with ``optimizerD.zero_grad()`` before the discriminator step.
        """
        pred_A = self.netD_A(self.fakeB, parallel=opt.parallel)
        pred_B = self.netD_B(self.fakeA, parallel=opt.parallel)
        # Generators try to make the discriminators output "real".
        self.loss_G_A = self.criterionGAN(
            pred_A, self.true_label.repeat(opt.batch_size,
                                           *pred_A[0][0].size()))
        self.loss_G_B = self.criterionGAN(
            pred_B, self.true_label.repeat(opt.batch_size,
                                           *pred_B[0][0].size()))
        self.loss_cycle_A = self.criterionCycle(self.recoA, self.realA)
        self.loss_cycle_B = self.criterionCycle(self.recoB, self.realB)
        self.loss_idt_A = self.criterionIdt(self.idtA, self.realB)
        self.loss_idt_B = self.criterionIdt(self.idtB, self.realA)

        self.loss_G = (self.loss_G_A + self.loss_G_B +
                       self.loss_cycle_A * self.lambdaA +
                       self.loss_cycle_B * self.lambdaB +
                       self.loss_idt_A * self.lambdaA * self.lambdaIdt +
                       self.loss_idt_B * self.lambdaB * self.lambdaIdt)

        self.loss_G.backward()

    def compute_loss_D_basic(self, netD, real, fake):
        """Return the LSGAN loss of ``netD`` on a real batch and a fake batch.

        ``fake`` is detached so that no gradient flows back into the
        generators. The result is the mean of the real and fake terms.
        """
        pred_real = netD(real, parallel=opt.parallel)
        pred_fake = netD(fake.detach(), parallel=opt.parallel)
        # NOTE(review): the [0][0] indexing presumably targets the per-GPU
        # output list of a parallel discriminator; confirm the target shape
        # when opt.parallel is False.
        loss_D_real = self.criterionGAN(
            pred_real,
            self.true_label.repeat(opt.batch_size, *pred_real[0][0].size()))
        loss_D_fake = self.criterionGAN(
            pred_fake,
            self.false_label.repeat(opt.batch_size, *pred_fake[0][0].size()))
        loss_D = (loss_D_real + loss_D_fake) / 2
        return loss_D

    def compute_loss_D_A(self):
        """Compute the loss of D_A, using a fake B image drawn from the image pool."""
        fake_B = self.fakeB_pool.query(self.fakeB)
        self.loss_D_A = self.compute_loss_D_basic(self.netD_A, self.realB,
                                                  fake_B)

    def compute_loss_D_B(self):
        """Compute the loss of D_B, using a fake A image drawn from the image pool."""
        fake_A = self.fakeA_pool.query(self.fakeA)
        self.loss_D_B = self.compute_loss_D_basic(self.netD_B, self.realA,
                                                  fake_A)

    def backward_D(self):
        """Compute both discriminator losses and back-propagate their sum."""
        self.compute_loss_D_A()
        self.compute_loss_D_B()
        self.loss_D = self.loss_D_A + self.loss_D_B
        self.loss_D.backward()

    def forward(self,
                realA: torch.Tensor,
                realB: torch.Tensor,
                parallel=opt.parallel):
        """Run both generators on a pair of batches and cache every output.

        :param realA: batch from domain A.
        :param realB: batch from domain B.
        :param parallel: forwarded to the generators. The default is read from
            ``opt`` once, when the class is defined.
        """
        self.realA = realA.to(self.device, non_blocking=True)
        self.realB = realB.to(self.device, non_blocking=True)
        # Translations and cycle reconstructions (G_A: A -> B, G_B: B -> A).
        self.fakeB = self.netG_A(self.realA, parallel=parallel)  # realA -G_A-> fakeB
        self.recoA = self.netG_B(self.fakeB, parallel=parallel)  # fakeB -G_B-> recoA
        self.fakeA = self.netG_B(self.realB, parallel=parallel)  # realB -G_B-> fakeA
        self.recoB = self.netG_A(self.fakeA, parallel=parallel)  # fakeA -G_A-> recoB

        # Identity mappings preserve colour composition: feeding a generator
        # an image that is already in its target domain should return it
        # unchanged.
        self.idtA = self.netG_A(self.realB, parallel=parallel)  # realB -G_A-> idtA
        self.idtB = self.netG_B(self.realA, parallel=parallel)  # realA -G_B-> idtB

    def backward(self):
        """Take one optimisation step for the generators, then one for the discriminators."""
        self.optimizerG.zero_grad()
        self.backward_G()
        self.optimizerG.step()

        # zero_grad also discards the D gradients produced by backward_G.
        self.optimizerD.zero_grad()
        self.backward_D()
        self.optimizerD.step()

    def get_current_images(self) -> dict:
        """Return the latest input and generated images, keyed by name."""
        return {
            'realA': self.realA,
            'realB': self.realB,
            'fakeA': self.fakeA,
            'fakeB': self.fakeB,
            'recoA': self.recoA,
            'recoB': self.recoB,
            'idtA': self.idtA,
            'idtB': self.idtB
        }

    def get_current_losses(self) -> dict:
        """Return this step's losses, keyed by name."""
        loss_names = [
            'G', 'G_A', 'G_B', 'Cycle_A', 'Cycle_B', 'Idt_A', 'Idt_B', 'D',
            'D_A', 'D_B'
        ]
        losses = [
            self.loss_G, self.loss_G_A, self.loss_G_B, self.loss_cycle_A,
            self.loss_cycle_B, self.loss_idt_A, self.loss_idt_B, self.loss_D,
            self.loss_D_A, self.loss_D_B
        ]
        return {loss_name: loss for loss_name, loss in zip(loss_names, losses)}

    def update_learning_rate(self):
        """Step both LR schedulers (call once per epoch) and print the new G learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizerG.param_groups[0]['lr']
        print(f'learning rate = {lr:.7f}')

    def save(self, epoch, total_time):
        """Save nets and optimizers to ``PATH('CycleGAN_ckpt.pth')``.

        :param epoch: epoch that has just finished; training resumes from epoch + 1.
        :param total_time: elapsed training time, stored under the key 'total time'.
        :return: None
        """
        savefile = PATH('CycleGAN_ckpt.pth')
        torch.save(
            {
                'epoch': epoch,
                'total time': total_time,
                'G_A': self.netG_A.state_dict(),
                'G_B': self.netG_B.state_dict(),
                'D_A': self.netD_A.state_dict(),
                'D_B': self.netD_B.state_dict(),
                'optimizerG': self.optimizerG.state_dict(),
                'optimizerD': self.optimizerD.state_dict()
            }, savefile)

    def load(self, path='auto') -> tuple:
        """Restore nets and optimizers from a checkpoint written by ``save``.

        :param path: checkpoint path; 'auto' means ``PATH('CycleGAN_ckpt.pth')``.
        :return: ``(epoch, total_time)`` as stored in the checkpoint.
        :raises KeyError: if the checkpoint has no elapsed-time entry.
        """
        checkpoint = torch.load(
            PATH('CycleGAN_ckpt.pth') if path == 'auto' else path)
        self.netG_A.load_state_dict(checkpoint['G_A'])
        self.netG_B.load_state_dict(checkpoint['G_B'])
        self.netD_A.load_state_dict(checkpoint['D_A'])
        self.netD_B.load_state_dict(checkpoint['D_B'])
        self.optimizerG.load_state_dict(checkpoint['optimizerG'])
        self.optimizerD.load_state_dict(checkpoint['optimizerD'])
        # save() writes the elapsed time under 'total time' (with a space).
        # Reading 'total_time' only made every save/load round trip raise
        # KeyError. Accept both spellings.
        if 'total time' in checkpoint:
            total_time = checkpoint['total time']
        else:
            total_time = checkpoint['total_time']
        return checkpoint['epoch'], total_time

    def train(self):
        """Put all four networks in training mode."""
        self.netG_A.train()
        self.netG_B.train()
        self.netD_A.train()
        self.netD_B.train()

    def eval(self):
        """Put all four networks in evaluation mode."""
        self.netG_A.eval()
        self.netG_B.eval()
        self.netD_A.eval()
        self.netD_B.eval()
                                transform=train_trans, image_set='train')
# Training loader: shuffled, incomplete last batch dropped.
train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
val_ds = VOCClassification('/path/to/VOC', transform=val_trans, image_set='val')
# NOTE(review): drop_last=True on the validation loader silently skips up to
# 7 images, and shuffling is unnecessary for evaluation; confirm both are intended.
val_dl = DataLoader(val_ds, batch_size=8, shuffle=True, num_workers=2, drop_last=True)


# Model
if args.arc == 'vgg':
    # ImageNet-pretrained VGG19 whose last classifier layer (index 6) is
    # replaced by a fresh Linear head.
    # Assumes train_ds.CLASSES is the number of classes (an int); TODO confirm.
    model = vgg19(pretrained=True)
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_ftrs, train_ds.CLASSES)
    model = DataParallelModel(model.cuda())
else:
    raise Exception("Architecture {} not found".format(args.arc))

# BCEWithLogitsLoss suggests multi-label targets (one logit per class);
# the criterion is wrapped to match the DataParallelModel outputs.
criterion = DataParallelCriterion(nn.BCEWithLogitsLoss().cuda())
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
# Multiply the LR by 0.2 every 20 scheduler steps (step_size=20).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.2)
best_pred = 0

# Load model
if args.resume:
    if not os.path.isfile(args.resume):
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    # NOTE(review): start_epoch is recorded here, but the scheduler state is
    # not restored; verify the training loop actually resumes from start_epoch.
    args.start_epoch = checkpoint['epoch']
    # Load into .module, i.e. the wrapped (non-parallel) network.
    model.module.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_pred = checkpoint['best_pred']
for epoch in range(args.epochs):
Example #13
0
class Trainer(CheckpointRunner):
    """Generic trainer on top of ``CheckpointRunner``.

    Subclasses implement ``init_model`` and ``init_loss_functions``, and may
    override the ``init_auxiliary`` hook. The model and criterion are wrapped
    for multi-GPU execution on ``self.gpus``. Attributes such as ``options``,
    ``gpus``, ``logger``, ``summary_writer``, ``dataset``, ``epoch_count``,
    ``step_count`` and ``time_elapsed`` are presumably provided by
    ``CheckpointRunner``; they are not set here.
    """
    # noinspection PyAttributeOutsideInit
    def init_fn(self, shared_model=None, **kwargs):
        """Build model, optimizer, LR scheduler, criterion and loss meter.

        :param shared_model: an already-built (and already-wrapped) model to
            reuse instead of calling ``init_model``.
        """
        # Create auxiliary models
        self.init_auxiliary()
        if shared_model is not None:
            self.model = shared_model
        else:
            self.model = self.init_model()
            self.model = DataParallelModel(self.model.cuda(),
                                           device_ids=self.gpus)
        # Setup a joint optimizer for the 2 models
        self.optimizer = self.init_optimizer(self.options.optim.name)
        self.lr_scheduler = self.init_lr(self.options.optim.lr_scheduler)
        # Create loss functions
        self.criterion = self.init_loss_functions()
        self.criterion = DataParallelCriterion(self.criterion.cuda(),
                                               device_ids=self.gpus)
        # Create AverageMeters for losses
        self.losses = AverageMeter()
        # Evaluators are not created here; test() skips evaluation unless a
        # subclass sets self.evaluators.
        self.dataset_size = None

    def init_auxiliary(self):
        """Hook for subclasses to create auxiliary models; does nothing by default."""
        pass

    def init_model(self):
        """Return the (unwrapped) model; must be implemented by subclasses."""
        raise NotImplementedError("Your model is not found")

    def init_loss_functions(self):
        """Return the loss module; must be implemented by subclasses."""
        raise NotImplementedError("Your loss is not found")

    def init_optimizer(self, optim_name):
        """Create the optimizer selected by ``optim_name``.

        :param optim_name: "adam", "sgd" or "adam_gan".
        :return: a single optimizer, or for "adam_gan" a dict with keys
            "optimizer_d" / "optimizer_g" over ``model.module.D`` / ``.G``.
        :raises NotImplementedError: for any other name.
        """
        if optim_name == "adam":
            optimizer = torch.optim.Adam(params=list(self.model.parameters()),
                                         lr=self.options.optim.lr,
                                         betas=(self.options.optim.adam_beta1,
                                                0.999),
                                         weight_decay=self.options.optim.wd)
        elif optim_name == "sgd":
            optimizer = torch.optim.SGD(
                params=list(self.model.parameters()),
                lr=self.options.optim.lr,
                momentum=self.options.optim.sgd_momentum,
                weight_decay=self.options.optim.wd)
        elif optim_name == "adam_gan":
            optimizer_d = torch.optim.Adam(
                params=list(self.model.module.D.parameters()),
                lr=self.options.optim.lr_d,
                betas=(self.options.optim.adam_beta1, 0.999),
                weight_decay=0)
            optimizer_g = torch.optim.Adam(
                params=list(self.model.module.G.parameters()),
                lr=self.options.optim.lr_g,
                betas=(self.options.optim.adam_beta1, 0.999),
                weight_decay=0)
            return {"optimizer_d": optimizer_d, "optimizer_g": optimizer_g}
        else:
            raise NotImplementedError("Your optimizer is not found")
        return optimizer

    def init_lr(self, lr_scheduler_name):
        """Create the LR scheduler selected by ``lr_scheduler_name``.

        :param lr_scheduler_name: "multistep", "exp" or "multistep_gan".
        :return: the scheduler, or None for any other name (``train`` then
            skips scheduler stepping).
        """
        if lr_scheduler_name == "multistep":
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, self.options.optim.lr_step,
                self.options.optim.lr_factor)
        elif lr_scheduler_name == "exp":
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                self.optimizer, gamma=self.options.optim.lr_gamma)
        elif lr_scheduler_name == "multistep_gan":
            # NOTE(review): only the discriminator optimizer is scheduled;
            # confirm the generator LR is meant to stay constant.
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer["optimizer_d"], self.options.optim.lr_step,
                self.options.optim.lr_factor)
        else:
            # Was misspelled as `r_scheduler`, which left `lr_scheduler`
            # unbound and raised UnboundLocalError on return.
            lr_scheduler = None

        return lr_scheduler

    def models_dict(self):
        """Return the models to checkpoint, keyed by name."""
        return {'model': self.model}

    def optimizers_dict(self):
        """Return the optimizer and LR scheduler to checkpoint, keyed by name."""
        return {'optimizer': self.optimizer, 'lr_scheduler': self.lr_scheduler}

    def train_step(self, input_batch):
        """Run one forward/backward/update step on a batch.

        Assumes ``self.optimizer`` is a single optimizer. The "adam_gan"
        dict form presumably requires a subclass override.

        :return: detached model outputs and detached loss summary.
        """
        # Grab data from the batch, predict with model
        out = self.model(input_batch)
        # compute loss
        loss, loss_summary = self.criterion(out, input_batch)
        self.losses.update(loss.detach().cpu().item())
        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Pack output arguments to be used for visualization
        return recursive_detach(out), recursive_detach(loss_summary)

    def get_dataloader(self):
        """Return the training DataLoader; the batch size is per-GPU size times GPU count."""
        data_loader = DataLoader(self.dataset,
                                 batch_size=self.options.train.batch_size *
                                 self.options.num_gpus,
                                 num_workers=self.options.num_workers,
                                 pin_memory=self.options.pin_memory,
                                 shuffle=self.options.train.shuffle)
        return data_loader

    def train(self):
        """Train from ``epoch_count`` to ``options.train.num_epochs``, with periodic logging and checkpoints."""
        self.logger.info("Start Trainning.")
        # Create data loader at very begining
        train_data_loader = self.get_dataloader()
        # Note: this is the number of batches, not the number of samples.
        self.dataset_size = len(train_data_loader)

        # Run training for num_epochs epochs
        for epoch in range(self.epoch_count, self.options.train.num_epochs):
            self.epoch_count += 1
            # Reset loss
            self.losses.reset()
            # Iterate over all batches in an epoch
            for step, batch in enumerate(train_data_loader):
                # Send input to GPU
                batch = {
                    k: v.cuda() if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }
                # Run training step
                out = self.train_step(batch)
                self.step_count += 1
                # Tensorboard logging every summary_steps steps
                if self.step_count % self.options.train.summary_steps == 0:
                    self.train_summaries(batch, *out)
                # Save checkpoint every checkpoint_steps steps
                if self.step_count % self.options.train.checkpoint_steps == 0:
                    self.dump_checkpoint()
            # End-of-epoch checkpoint, except for GAN models.
            if not self.options.model.name.endswith('gan'):
                self.dump_checkpoint()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

    def train_summaries(self, input_batch, out_summary, loss_summary):
        """Log the batch filenames at debug level, then write losses to TensorBoard and the log."""
        # Debug info for filenames
        self.logger.debug(input_batch["filename"])
        # Save results in Tensorboard
        self.tensorboard_step(loss_summary)
        # Save results to log
        self.log_step(loss_summary)

    def log_step(self, loss_summary):
        """Log progress: epoch, step / total steps, elapsed time, current and average loss."""
        self.logger.info(
            "Epoch %03d, Step %06d/%06d, Time elapsed %s, Loss %.5f (AvgLoss %.5f)"
            % (self.epoch_count, self.step_count,
               self.options.train.num_epochs * len(self.dataset) //
               (self.options.train.batch_size * self.options.num_gpus),
               self.time_elapsed, self.losses.val, self.losses.avg))

    def tensorboard_step(self, loss_summary):
        """Write every entry of ``loss_summary`` as a scalar at the current step."""
        for k, v in loss_summary.items():
            self.summary_writer.add_scalar(k, v, self.step_count)

    def init_with_pretrained_backbone(self):
        """Non-strictly load ``options.train.backbone_pretrained_model`` into the wrapped model."""
        checkpoint_file = os.path.abspath(
            self.options.train.backbone_pretrained_model)
        pretrained_dict = torch.load(checkpoint_file)
        self.model.module.load_state_dict(pretrained_dict, strict=False)
        self.logger.info("Init with pre-trained backbone from %s." %
                         checkpoint_file)

    def test(self):
        """Run every configured evaluator with the model in eval mode.

        ``init_fn`` does not create evaluators, so if no subclass set
        ``self.evaluators`` this logs a warning and returns instead of
        raising AttributeError. Train mode is restored even if an
        evaluator raises.
        """
        evaluators = getattr(self, 'evaluators', None)
        if evaluators is None:
            self.logger.warning("No evaluators configured; skipping test.")
            return
        self.model.eval()
        try:
            for evaluator in evaluators:
                evaluator.evaluate()
        finally:
            self.model.train()
Example #14
0
def main(args):
    """Train the segmentation model and keep the checkpoint with the best val mIoU.

    The function:
    - prints the arguments, creates ``args.snapshot_dir`` and a SummaryWriter;
    - seeds the RNGs;
    - builds the model via ``get_model`` and initialises it from
      ``args.restore_from``: every key whose first component is not ``fc`` is
      copied to ``'encoder.' + key``;
    - trains for ``args.epochs`` epochs, validating every 10th epoch and on
      the last four epochs.

    Whenever val mIoU improves, ``seg_model``'s state dict is saved to
    ``<snapshot_dir>/<method>_miou.pth``. At the end that file is renamed to
    include the best mIoU. If no checkpoint was ever saved, the rename is
    skipped instead of failing with UnboundLocalError.

    :param args: parsed command-line namespace.
    """
    # initialization
    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    os.makedirs(args.snapshot_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.method))

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True

    # 19x19 adjacency matrix passed to both the model and the loss
    # (presumably adjacency between the 19 foreground part labels; confirm).
    adj_matrix = torch.tensor(
        [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
         [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
         [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
        requires_grad=False)
    # Label ids (1-based) grouped into upper / lower halves.
    upper_part_list = [1, 2, 3, 4, 5, 6, 7, 11, 13, 14, 15]
    lower_part_list = [8, 9, 10, 12, 16, 17, 18, 19]
    # Per-class loss weights (20 entries; presumably background + 19 parts).
    weight = torch.FloatTensor([
        0.7602572, 0.94236198, 0.85644457, 1.04346266, 1.10627293, 0.80980162,
        0.95168713, 0.8403769, 1.05798412, 0.85746254, 1.01274366, 1.05854692,
        1.03430773, 0.84867818, 0.88027721, 0.87580925, 0.98747462, 0.9876475,
        1.00016535, 1.00108882
    ])

    # conduct seg network
    seg_model = get_model(num_classes=args.num_classes,
                          adj_matrix=adj_matrix,
                          upper_part_list=upper_part_list,
                          lower_part_list=lower_part_list)

    saved_state_dict = torch.load(args.restore_from)
    new_params = seg_model.state_dict().copy()

    # Copy every pretrained tensor except the 'fc' head into the model's
    # encoder. To load a checkpoint that already matches seg_model exactly,
    # use seg_model.load_state_dict(saved_state_dict) instead.
    for key in saved_state_dict:
        if key.split('.')[0] != 'fc':
            new_params['encoder.' + key] = saved_state_dict[key]
    seg_model.load_state_dict(new_params)
    print('loading params w/o fc')

    model = DataParallelModel(seg_model)
    model.float()
    model.cuda()

    # define dataloader
    train_loader = data.DataLoader(DatasetGenerator(root=args.root,
                                                    list_path=args.lst,
                                                    crop_size=args.crop_size,
                                                    training=True),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4,
                                   pin_memory=True)
    val_loader = data.DataLoader(DatasetGenerator(root=args.val_root,
                                                  list_path=args.val_lst,
                                                  crop_size=args.crop_size,
                                                  training=False),
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=True)

    # define criterion & optimizer
    criterion = ABRLovaszLoss(adj_matrix=adj_matrix,
                              ignore_index=args.ignore_label,
                              only_present=True,
                              upper_part_list=upper_part_list,
                              lower_part_list=lower_part_list,
                              cls_p=args.num_classes,
                              cls_h=args.hbody_cls,
                              cls_f=args.fbody_cls,
                              weight=weight)
    criterion = DataParallelCriterion(criterion).cuda()

    optimizer = optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad,
                             seg_model.parameters()),
            'lr': args.learning_rate
        }],
        lr=args.learning_rate,
        momentum=0.9,
        weight_decay=5e-4)

    # key points
    best_val_mIoU = 0
    best_val_pixAcc = 0
    # Path of the best-mIoU checkpoint; stays None until one is written.
    model_dir = None
    start = time.time()

    for epoch in range(0, args.epochs):
        print('\n{} | {}'.format(epoch, args.epochs - 1))
        # training
        _ = train(model, train_loader, epoch, criterion, optimizer, writer)
        # validation: every 10th epoch and the last four epochs
        if epoch % 10 == 0 or epoch > args.epochs - 5:
            val_pixacc, val_miou = validation(model, val_loader, epoch, writer)
            # save model
            if val_pixacc > best_val_pixAcc:
                best_val_pixAcc = val_pixacc
            if val_miou > best_val_mIoU:
                best_val_mIoU = val_miou
                model_dir = os.path.join(args.snapshot_dir,
                                         args.method + '_miou.pth')
                torch.save(seg_model.state_dict(), model_dir)
                print('Model saved to %s' % model_dir)

    # Tag the best checkpoint with its mIoU. Guarded because model_dir is
    # unset when no validation round ever improved on 0.
    if model_dir is not None:
        os.rename(
            model_dir,
            os.path.join(args.snapshot_dir,
                         args.method + '_miou' + str(best_val_mIoU) + '.pth'))
    else:
        print('No checkpoint saved: validation mIoU never exceeded 0')
    print('Complete using', time.time() - start, 'seconds')
    print('Best pixAcc: {} | Best mIoU: {}'.format(best_val_pixAcc,
                                                   best_val_mIoU))