Beispiel #1
0
def main():
    """Build the networks selected by ``cfg.TRAIN.COM_EXP`` and train them.

    Experiment ids (taken from the original dispatch comments):
        3 -- FCN + SSD
        4 -- FCN + SSD + DC
        5 -- full model (FCN + SSD + DC + OBC)
        6 -- FCN + SSD + OBC

    Raises:
        ValueError: if ``cfg.TRAIN.COM_EXP`` is not one of 3, 4, 5, 6. An
            unknown id used to load the data and then return without
            training anything.
    """
    com_exp = cfg.TRAIN.COM_EXP
    if com_exp not in (3, 4, 5, 6):
        raise ValueError(
            'Unsupported cfg.TRAIN.COM_EXP={!r}; expected one of 3, 4, 5, 6'
            .format(com_exp))
    use_dc = com_exp in (4, 5)
    use_obc = com_exp in (5, 6)

    if not cfg.TRAIN.MULTI_GPU:
        torch.cuda.set_device(cfg.TRAIN.GPU_ID[0])

    i_tb = 0
    # loading data
    src_loader, tgt_loader, restore_transform = load_dataset()
    data_encoder = DataEncoder()

    # initialize models; components not used by this experiment are ignored
    ext_model, dc_model, obc_model, cur_epoch = init_model(cfg.TRAIN.NET)

    # criteria for the extractor: segmentation losses for the FCN-8s and the
    # detection loss for the SSD head
    spvsd_cri = CrossEntropyLoss2d(cfg.TRAIN.LABEL_WEIGHT).cuda()
    unspvsd_cri = CrossEntropyLoss2d(cfg.TRAIN.LABEL_WEIGHT).cuda()
    det_cri = MultiBoxLoss()
    # the ext_opt is set in train_net.py, because the ssd learning rate is stepwise

    # Optional adversarial components, passed as keyword arguments only for
    # the experiments that use them.
    optional = {}
    if use_dc:
        optional.update(
            dc_model=dc_model,
            dc_cri=CrossEntropyLoss2d().cuda(),
            dc_invs_cri=CrossEntropyLoss2d().cuda(),
            dc_opt=optim.Adam(dc_model.parameters(), lr=cfg.TRAIN.DC_LR,
                              betas=(0.5, 0.999)),
        )
    if use_obc:
        optional.update(
            obc_model=obc_model,
            obc_cri=CrossEntropyLoss().cuda(),
            obc_invs_cri=CrossEntropyLoss().cuda(),
            obc_opt=optim.Adam(obc_model.parameters(), lr=cfg.TRAIN.OBC_LR,
                               betas=(0.5, 0.999)),
        )

    train_adversarial(cur_epoch, i_tb, data_encoder, src_loader, tgt_loader,
                      restore_transform, ext_model, spvsd_cri, unspvsd_cri,
                      det_cri, **optional)
Beispiel #2
0
def main():
    """Train ``Net`` on the PolSAR dataset, then close the summary writer.

    All hyper-parameters come from the module-level ``args`` dict. The loss
    ignores label 0 (``ignore_index=0``).
    """
    # Output size is the dataset class count plus one; presumably the extra
    # slot is label 0, which the loss ignores -- confirm against RS.
    net = Net(4, num_classes=RS.num_classes + 1)  #, pretrained=True
    if args['gpu']:
        net = net.cuda()

    train_set = RS.PolSAR(mode='train', random_flip=True,
                          crop_size=args['train_crop_size'])
    val_set = RS.PolSAR(mode='val', random_flip=False, crop_size=False)

    # Only the training split is shuffled.
    train_loader, val_loader = (
        DataLoader(dataset, batch_size=args[size_key], num_workers=4,
                   shuffle=shuffle)
        for dataset, size_key, shuffle in (
            (train_set, 'train_batch_size', True),
            (val_set, 'val_batch_size', False),
        )
    )

    criterion = CrossEntropyLoss2d(ignore_index=0)
    # criterion = FocalLoss2d(gamma=2.0, ignore_index=0)
    if args['gpu']:
        criterion = criterion.cuda()

    # Frozen parameters (requires_grad=False) are not handed to the optimizer.
    trainable_params = (p for p in net.parameters() if p.requires_grad)
    optimizer = optim.SGD(trainable_params,
                          lr=args['lr'],
                          momentum=args['momentum'],
                          weight_decay=args['weight_decay'],
                          nesterov=True)
    # Multiply the learning rate by 0.95 after every scheduler step.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95,
                                          last_epoch=-1)

    train(train_loader, net, criterion, optimizer, scheduler, args, val_loader)
    writer.close()
    print('Training finished.')
Beispiel #3
0
def main():
    """Train FCN8ResNet on CityScapes, optionally resuming from a snapshot.

    Configuration comes from module-level names (``train_args``, ``val_args``,
    ``num_classes``, ``ignored_label``, ``ckpt_path``, ``exp_name``). When
    resuming, ``train_record`` is updated from fields of the snapshot name.
    """
    net = FCN8ResNet(num_classes=num_classes).cuda()
    snapshot = train_args['snapshot']
    if len(snapshot) == 0:
        curr_epoch = 0
    else:
        # print() with a single argument behaves identically on Python 2 and
        # 3; the former print statement was a SyntaxError on Python 3.
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        # The snapshot name is '_'-separated: field 1 is the epoch, field 3
        # the best val loss and field 6 the corresponding mean IoU.
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train',
                           simul_transform=train_simul_transform,
                           transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=16,
                              shuffle=True)
    val_set = CityScapes('val',
                         simul_transform=val_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=val_args['batch_size'],
                            num_workers=16,
                            shuffle=False)

    weight = torch.ones(num_classes)
    # Zero weight for the last class; target_transform presumably relabels
    # ignored_label to num_classes - 1, so those pixels do not contribute.
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    def select(bias, new):
        """Parameters named '*bias' iff ``bias``, containing 'fconv' iff ``new``."""
        return [
            param for name, param in net.named_parameters()
            if (name[-4:] == 'bias') == bias and ('fconv' in name) == new
        ]

    # Group order matters: it must match any saved optimizer state and the
    # param_groups[i] indices below. Biases get 2x lr and no weight_decay.
    optimizer = optim.SGD([
        {'params': select(bias=True, new=True),
         'lr': 2 * train_args['new_lr']},
        {'params': select(bias=False, new=True),
         'lr': train_args['new_lr'],
         'weight_decay': train_args['weight_decay']},
        {'params': select(bias=True, new=False),
         'lr': 2 * train_args['pretrained_lr']},
        {'params': select(bias=False, new=False),
         'lr': train_args['pretrained_lr'],
         'weight_decay': train_args['weight_decay']},
    ], momentum=0.9, nesterov=True)

    if len(snapshot) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # Override the restored learning rates with the configured ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['new_lr']
        optimizer.param_groups[1]['lr'] = train_args['new_lr']
        optimizer.param_groups[2]['lr'] = 2 * train_args['pretrained_lr']
        optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']

    exp_dir = os.path.join(ckpt_path, exp_name)
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)  # also creates ckpt_path if missing

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
Beispiel #4
0
    gamma=config["optimizer"]["lr_step"],
)

# Per-class weights for task 1 (road segmentation). All ones by default; when
# config["task1_weight"] < 1, class 0 gets 1 - task1_weight and class 1 gets
# task1_weight.
weights = torch.ones(config["task1_classes"]).cuda()
if config["task1_weight"] < 1:
    print("Roads are weighted.")
    weights[0] = 1 - config["task1_weight"]
    weights[1] = config["task1_weight"]

# Per-class weights for task 2 (road angles). All ones by default; when
# config["task2_weight"] < 1, only the LAST class is re-weighted.
weights_angles = torch.ones(config["task2_classes"]).cuda()
if config["task2_weight"] < 1:
    print("Road angles are weighted.")
    weights_angles[-1] = config["task2_weight"]

# Weighted cross-entropy for the angle task; target value 255 is ignored.
# NOTE(review): size_average/reduce are presumably forwarded to a torch.nn loss,
# where they are deprecated in favour of `reduction` -- confirm in
# CrossEntropyLoss2d.
angle_loss = CrossEntropyLoss2d(weight=weights_angles,
                                size_average=True,
                                ignore_index=255,
                                reduce=True).cuda()
# IoU-based loss for the road task, using the task-1 class weights.
road_loss = mIoULoss(weight=weights,
                     size_average=True,
                     n_classes=config["task1_classes"]).cuda()


def train(epoch):
    train_loss_iou = 0
    train_loss_vec = 0
    model.train()
    optimizer.zero_grad()
    hist = np.zeros((config["task1_classes"], config["task1_classes"]))
    hist_angles = np.zeros((config["task2_classes"], config["task2_classes"]))
    crop_size = config["train_dataset"][args.dataset]["crop_size"]
    for i, data in enumerate(train_loader, 0):
Beispiel #5
0
def test_func(args):
    """
     Main function for testing: builds CGNet, optionally restores a
     checkpoint and runs ``test`` over the Cityscapes test list.
     param args: global arguments
     return: None
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "no GPU found or wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('=====> checking if processed cached_data_file exists')
    if not os.path.isfile(args.inform_data_file):
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std and class-weight information
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            exit(-1)
    else:
        # Close the cache file deterministically (the handle used to leak).
        # NOTE(review): pickle.load can execute arbitrary code; only use
        # locally generated, trusted cache files here.
        with open(args.inform_data_file, "rb") as cache_file:
            data = pickle.load(cache_file)
    M = args.M
    N = args.N

    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("Arch:  CGNet")
    # define optimization criteria
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to torch
    if args.cuda:
        weight = weight.cuda()
    criteria = CrossEntropyLoss2d(weight)  #weight

    if args.cuda:
        model = model.cuda()  # using GPU for inference
        criteria = criteria.cuda()
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    if args.save_seg_dir:
        if not os.path.exists(args.save_seg_dir):
            os.makedirs(args.save_seg_dir)

    # test set (no shuffling, one image per batch)
    testLoader = torch.utils.data.DataLoader(CityscapesTestDataSet(
        args.data_dir, args.test_data_list, mean=data['mean']),
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #model.load_state_dict(checkpoint['model'])
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    print("=====> beginning testing")
    print("test set length: ", len(testLoader))
    test(args, testLoader, model)
Beispiel #6
0
def main(args):
    """Evaluate a pretrained segmentation model on the validation list.

    Builds the model named by ``args.generatormodel``, loads its weights from a
    URL (``http...``) or a local path, wraps it with a bilinear upsampling layer
    of size ``args.segSize`` and a CrossEntropyLoss2d criterion
    (``ignore_index=-1``), moves all three to the GPU and calls ``evaluate``.
    """
    if not os.path.exists(args.result):
        os.makedirs(args.result)

    # create network
    model = get_model(name=args.generatormodel, num_classes=args.num_classes)

    if args.pretrained_model is not None:
        # Bug fix: this used to read the non-existent `args.pretrainned_model`.
        # NOTE(review): the dict name below is kept exactly as in the original
        # lookup; confirm it matches the module-level definition.
        args.restore_from = pretrianed_models_dict[args.pretrained_model]

    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    model.load_state_dict(saved_state_dict)

    crit = CrossEntropyLoss2d(ignore_index=-1)

    interp = nn.Upsample(size=(args.segSize, args.segSize), mode='bilinear')

    val_dataset = MITSceneParsingDataset(args.list_val, args,
                                         max_sample=args.num_val, is_train=0)

    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers),
                                 pin_memory=True,
                                 drop_last=True)

    nets = (model, interp, crit)

    # Module.cuda() moves parameters/buffers in place, so the tuple entries
    # are updated without rebinding.
    for net in nets:
        net.cuda()

    # Main loop
    evaluate(nets, val_loader, args)

    print('Evaluation Done!')
Beispiel #7
0
# Hand-set per-class loss weights for classes 0-18 used by train().
# TODO: compute these from a label histogram of the dataset (class balancing)
# instead of hard-coding them.
_ENCODER_CLASS_WEIGHTS = (
    2.3653597831726, 4.4237880706787, 2.9691488742828, 5.3442072868347,
    5.2983593940735, 5.2275490760803, 5.4394111633301, 5.3659925460815,
    3.4170460700989, 5.2414722442627, 4.7376127243042, 5.2286224365234,
    5.455126285553, 4.3019247055054, 5.4264230728149, 5.4331531524658,
    5.433765411377, 5.4631009101868, 5.3947434425354,
)
_FULL_CLASS_WEIGHTS = (
    2.8149201869965, 6.9850029945374, 3.7890393733978, 9.9428062438965,
    9.7702074050903, 9.5110931396484, 10.311357498169, 10.026463508606,
    4.6323022842407, 9.5608062744141, 7.8698215484619, 9.5168733596802,
    10.373730659485, 6.6616044044495, 10.260489463806, 10.287888526917,
    10.289801597595, 10.405355453491, 10.138095855713,
)


def _train_class_weights(enc):
    """Return the NUM_CLASSES-long loss weight tensor used by train().

    Classes 0-18 get the encoder-only or full-model weights above; index 19 is
    set to 0 so that class presumably contributes nothing to the weighted loss.
    """
    weight = torch.ones(NUM_CLASSES)
    values = _ENCODER_CLASS_WEIGHTS if enc else _FULL_CLASS_WEIGHTS
    for idx, value in enumerate(values):
        weight[idx] = value
    weight[19] = 0
    return weight


def _plot_batch(board, prefix, inputs, outputs, targets, epoch, step):
    """Show the first input, prediction and target of a batch on ``board``."""
    start_time_plot = time.time()
    image = inputs[0].cpu().data
    board.image(image, f'{prefix}input (epoch: {epoch}, step: {step})')
    if isinstance(outputs, list):  # merge gpu tensors
        prediction = outputs[0][0]
    else:
        prediction = outputs[0]
    board.image(color_transform(prediction.cpu().max(0)[1].data.unsqueeze(0)),
                f'{prefix}output (epoch: {epoch}, step: {step})')
    board.image(color_transform(targets[0].cpu().data),
                f'{prefix}target (epoch: {epoch}, step: {step})')
    print("Time to paint images: ", time.time() - start_time_plot)


def _print_running_loss(prefix, losses, times, batch_size, epoch, step):
    """Print the running mean loss and mean time per image so far."""
    average = sum(losses) / len(losses)
    print(f'{prefix}loss: {average:0.4} (epoch: {epoch}, step: {step})',
          "// Avg time/img: %.4f s" % (sum(times) / len(times) / batch_size))


def train(args, model, enc=False):
    """Train ``model`` on Cityscapes, validating and checkpointing every epoch.

    Args:
        args: parsed options (datadir, height, num_workers, batch_size, cuda,
            savedir, resume, num_epochs, visualize, steps_plot, port,
            steps_loss, iouTrain, iouVal, epochs_save).
        model: network called as ``model(inputs, only_encode=enc)``.
        enc: train only the encoder; selects the encoder class weights and the
            ``*_enc`` / ``*_encoder`` output files.

    Returns:
        The trained model (convenience for encoder-decoder training).
    """
    # Start below any reachable score so the first epoch is always recorded as
    # best, including when the score is the negated val loss (an initial 0
    # could never be beaten by a negative value).
    best_acc = float('-inf')

    weight = _train_class_weights(enc)

    assert os.path.exists(
        args.datadir), "Error: datadir (dataset directory) could not be loaded"

    co_transform = MyCoTransform(enc, augment=True, height=args.height)  #512)
    co_transform_val = MyCoTransform(enc, augment=False,
                                     height=args.height)  #512)
    dataset_train = cityscapes(args.datadir, co_transform, 'train')
    dataset_val = cityscapes(args.datadir, co_transform_val, 'val')

    loader = DataLoader(dataset_train,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=True)
    loader_val = DataLoader(dataset_val,
                            num_workers=args.num_workers,
                            batch_size=args.batch_size,
                            shuffle=False)

    if args.cuda:
        weight = weight.cuda()
    criterion = CrossEntropyLoss2d(weight)
    print(type(criterion))

    savedir = f'../save/{args.savedir}'

    if enc:
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"

    if not os.path.exists(automated_log_path):  # don't add header if it exists
        with open(automated_log_path, "a") as myfile:
            myfile.write(
                "Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate"
            )

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))

    # TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4
    # https://github.com/pytorch/pytorch/issues/1893
    # Alternative ("scheduler 1"): Adam(..., weight_decay=2e-4) with ReduceLROnPlateau.
    optimizer = Adam(model.parameters(),
                     5e-4, (0.9, 0.999),
                     eps=1e-08,
                     weight_decay=1e-4)  ## scheduler 2

    start_epoch = 1
    if args.resume:
        # Must load weights, optimizer, epoch and best value.
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(
            filenameCheckpoint
        ), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {}".format(checkpoint['epoch']))

    # "scheduler 2": polynomial decay, lr factor (1 - (epoch - 1) / num_epochs) ** 0.9
    lambda1 = lambda epoch: pow(
        (1 - ((epoch - 1) / args.num_epochs)), 0.9)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer,
                                      lr_lambda=lambda1)  ## scheduler 2

    board = None
    if args.visualize and args.steps_plot > 0:
        board = Dashboard(args.port)

    doIouTrain = args.iouTrain
    doIouVal = args.iouVal

    for epoch in range(start_epoch, args.num_epochs + 1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        # NOTE: passing the epoch explicitly is deprecated in recent PyTorch;
        # kept because lambda1 is written in terms of the absolute epoch.
        scheduler.step(epoch)  ## scheduler 2

        epoch_loss = []
        time_train = []

        if doIouTrain:
            iouEvalTrain = iouEval(NUM_CLASSES)

        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model.train()
        for step, (images, labels) in enumerate(loader):
            start_time = time.time()

            # Stop at a trailing partial batch.
            if images.shape[0] != args.batch_size:
                break

            # Bind inputs/targets on CPU too (they were undefined without --cuda).
            inputs, targets = images, labels
            if args.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            outputs = model(inputs, only_encode=enc)

            optimizer.zero_grad()
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())
            time_train.append(time.time() - start_time)

            if doIouTrain:
                iouEvalTrain.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)

            if board is not None and step % args.steps_plot == 0:
                _plot_batch(board, '', inputs, outputs, targets, epoch, step)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                _print_running_loss('', epoch_loss, time_train,
                                    args.batch_size, epoch, step)

        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)

        iouTrain = 0
        if doIouTrain:
            iouTrain, _ = iouEvalTrain.getIoU()
            iouStr = getColorEntry(iouTrain) + '{:0.2f}'.format(
                iouTrain * 100) + '\033[0m'
            print("EPOCH IoU on TRAIN set: ", iouStr, "%")

        # Validate on the val split after each epoch of training
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        if doIouVal:
            iouEvalVal = iouEval(NUM_CLASSES)

        for step, (images, labels) in enumerate(loader_val):
            start_time = time.time()

            if images.shape[0] != args.batch_size:
                break

            inputs, targets = images, labels
            if args.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            # Forward pass and loss both run without autograd; previously only
            # the Variable wrapping was inside no_grad, so graphs were built.
            with torch.no_grad():
                outputs = model(inputs, only_encode=enc)
                loss = criterion(outputs, targets[:, 0])
            epoch_loss_val.append(loss.item())
            time_val.append(time.time() - start_time)

            # Add batch to calculate TP, FP and FN for iou estimation
            if doIouVal:
                iouEvalVal.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)

            if board is not None and step % args.steps_plot == 0:
                _plot_batch(board, 'VAL ', inputs, outputs, targets, epoch,
                            step)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                _print_running_loss('VAL ', epoch_loss_val, time_val,
                                    args.batch_size, epoch, step)

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)

        iouVal = 0
        if doIouVal:
            iouVal, _ = iouEvalVal.getIoU()
            iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(
                iouVal * 100) + '\033[0m'
            print("EPOCH IoU on VAL set: ", iouStr, "%")

        # Score is the val IoU, or the negated val loss when IoU is 0 / not computed.
        current_acc = -average_epoch_loss_val if iouVal == 0 else iouVal
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, filenameCheckpoint, filenameBest)

        # SAVE MODEL AFTER EPOCH
        if enc:
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        # Periodic save keyed on the epoch (it used the last val step index).
        if args.epochs_save > 0 and epoch % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if is_best:
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            best_txt = savedir + ("/best_encoder.txt" if enc else "/best.txt")
            with open(best_txt, "w") as myfile:
                myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                             (epoch, iouVal))

        # Append one row: epoch, train loss, val loss, train IoU, val IoU, lr
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" %
                         (epoch, average_epoch_loss_train,
                          average_epoch_loss_val, iouTrain, iouVal, usedLr))

    return model  # convenience for encoder-decoder training
Beispiel #8
0
def train_model(args):
    """
    Main function for training CGNet on Cityscapes.

    Loads (or computes and caches) the dataset statistics, builds the
    Context Guided Network, optionally resumes from a checkpoint and runs the
    training loop. Validation runs every 50 epochs and a checkpoint
    ``model_<epoch+1>.pth`` is written after every epoch.

    Args:
       args: global arguments (parsed command-line namespace). ``args.seed``
             and ``args.savedir`` are overwritten in place.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> Check if processed data file exists or not")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std and class-weight information
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print('Error while pickling data. Please check.')
            exit(-1)
    else:
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # you trust. The `with` block makes sure the handle is closed.
        with open(args.inform_data_file, "rb") as f:
            datas = pickle.load(f)

    print(args)
    global network_type

    if args.cuda:
        print("=====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.enabled = True
    M = args.M
    N = args.N
    # load the model
    print('=====> Building network')
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("current architeture:  CGNet")
    args.savedir = args.savedir + network_type + "_M" + str(M) + 'N' + str(
        N) + '/'

    # create the directory of checkpoint if not exist
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    print('=====> Computing network parameters')
    total_paramters = netParams(model)
    print('Total network parameters: ' + str(total_paramters))

    print("data['classWeights']: ", datas['classWeights'])
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight)
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    if args.cuda:
        # Move the loss to the GPU only when CUDA was requested; doing this
        # unconditionally breaks CPU-only runs.
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            model = torch.nn.DataParallel(
                model).cuda()  # multi-card data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  # single card

    start_epoch = 0

    # DataLoader
    trainLoader = data.DataLoader(CityscapesDataSet(args.data_dir,
                                                    args.train_data_list,
                                                    crop_size=input_size,
                                                    scale=args.random_scale,
                                                    mirror=args.random_mirror,
                                                    mean=datas['mean']),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

    valLoader = data.DataLoader(CityscapesValDataSet(args.data_dir,
                                                     args.val_data_list,
                                                     f_scale=1,
                                                     mean=datas['mean']),
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        # Column order matches the rows written in the training loop below.
        logger.write(
            "\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
            ('Epoch', 'Loss(Tr)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
    logger.flush()
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=5e-4)

    try:
        print('=====> beginning training')
        for epoch in range(start_epoch, args.max_epochs):
            lossTr, per_class_iu_tr, mIOU_tr, lr = train(
                args, trainLoader, model, criteria, optimizer, epoch)
            # evaluate on validation set
            if epoch % 50 == 0:
                mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                             (epoch, lossTr, mIOU_tr, mIOU_val, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            else:
                # No validation this epoch: leave the mIOU (val) column empty
                # so that lr stays under its own header.
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t\t\t%.7f" %
                             (epoch, lossTr, mIOU_tr, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, lr))
            # save the model
            model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
            state = {"epoch": epoch + 1, "model": model.state_dict()}
            torch.save(state, model_file_name)
    finally:
        # close the log even if training is interrupted
        logger.close()
Beispiel #9
0
def train_model(args):
    """
    Train a MobileNetV3 (SMALL) segmentation model on Cityscapes.

    Loads (or computes and caches) the dataset statistics, builds the model,
    optionally resumes from ``args.resume`` and runs the training loop.
    Validation runs every 50 epochs. Snapshots are saved every 20 epochs and
    for each of the last 9 epochs.

    args:
       args: global arguments (parsed command-line namespace). ``args.seed``,
             ``args.gpu_nums`` (when CUDA is used) and ``args.savedir`` are
             set in place.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> checking if inform_data_file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std and class-weight information
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(args.inform_data_file))
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # you trust. The `with` block makes sure the handle is closed.
        with open(args.inform_data_file, "rb") as f:
            datas = pickle.load(f)

    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # fixed seed so runs are reproducible
    args.seed = 9830

    print("====> Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.enabled = True

    model = MobileNetV3(model_mode="SMALL", num_classes=args.classes)

    network_type = "MobileNetV3"
    print("=====> current architeture:  MobileNetV3")

    print("=====> computing network parameters")
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))

    print("data['classWeights']: ", datas['classWeights'])
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    # define optimization criteria
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = torch.nn.DataParallel(
                model).cuda()  # multi-card data parallel
        else:
            # record the GPU count here as well: it is part of the save path
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # single card

    # On CPU runs args.gpu_nums may never have been set; default to 1.
    args.savedir = (args.savedir + args.dataset + '/' + network_type + 'bs' +
                    str(args.batch_size) + 'gpu' +
                    str(getattr(args, 'gpu_nums', 1)) + "_" +
                    str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    trainLoader = data.DataLoader(CityscapesDataSet(args.data_dir,
                                                    args.train_data_list,
                                                    crop_size=input_size,
                                                    scale=args.random_scale,
                                                    mirror=args.random_mirror,
                                                    mean=datas['mean']),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    valLoader = data.DataLoader(CityscapesValDataSet(args.data_dir,
                                                     args.val_data_list,
                                                     f_scale=1,
                                                     mean=datas['mean']),
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    log_exists = os.path.isfile(logFileLoc)
    logger = open(logFileLoc, 'a' if log_exists else 'w')
    # When appending to an existing log, start the new section on a new line.
    logger.write(("\n" if log_exists else "") +
                 "Global configuration as follows:")
    for key, value in vars(args).items():
        logger.write("\n{:16} {}".format(key, value))
    logger.write("\nParameters: %s" % (str(total_paramters)))
    # Column order matches the rows written in the training loop below.
    logger.write(
        "\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
        ('Epoch', 'Loss(Tr)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=5e-4)

    try:
        print('=====> beginning training')
        for epoch in range(start_epoch, args.max_epochs):
            # training
            lossTr, per_class_iu_tr, mIOU_tr, lr = train(
                args, trainLoader, model, criteria, optimizer, epoch)

            # validation
            if epoch % 50 == 0:
                mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
                # record train information
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                             (epoch, lossTr, mIOU_tr, mIOU_val, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            else:
                # record train information; the empty field keeps lr aligned
                # under its header when there is no validation mIOU
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t\t\t%.7f" %
                             (epoch, lossTr, mIOU_tr, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, lr))

            # save the model: the last 9 epochs plus every 20th epoch
            model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
            state = {"epoch": epoch + 1, "model": model.state_dict()}
            if epoch > args.max_epochs - 10 or epoch % 20 == 0:
                torch.save(state, model_file_name)
    finally:
        # close the log even if training is interrupted
        logger.close()
Beispiel #10
0
def train_model(args):
    """
    Train a segmentation model (``args.model``) on camvid or cityscapes.

    Builds and initializes the model, loads the dataset via
    ``build_dataset_train``, optionally resumes from ``args.resume`` and runs
    the training loop. Validation runs every 30 epochs and on the last epoch;
    loss/IoU plots are refreshed on the same schedule. Snapshots are saved
    every 100 epochs and for each of the last 10 epochs.

    args:
       args: global arguments (parsed command-line namespace).
             ``args.savedir`` (and ``args.gpu_nums`` when CUDA is used) are
             set in place.

    Raises:
       NotImplementedError: if ``args.dataset`` is neither camvid nor
             cityscapes.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> input size:{}".format(input_size))

    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("=====> set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True
    print("=====> building network")

    # build the model and initialization
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    print("=====> computing network parameters and FLOPs")
    total_paramters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" %
          (total_paramters, (total_paramters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, input_size, args.batch_size, args.train_type,
        args.random_scale, args.random_mirror, args.num_workers)

    print('=====> Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])

    # define loss function, respectively
    weight = torch.from_numpy(datas['classWeights'])

    if args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    elif args.dataset == 'cityscapes':
        # args.gpus is the CUDA_VISIBLE_DEVICES string (e.g. "0,1"), so count
        # its comma-separated ids rather than its characters (len("0,1") == 3).
        if isinstance(args.gpus, str):
            gpu_ids = [g for g in args.gpus.split(',') if g.strip()]
        else:
            gpu_ids = list(args.gpus)
        num_gpus = max(1, len(gpu_ids))
        # min_kept is computed per GPU: 1/16 of the pixels in each per-GPU
        # share of the batch (presumably the OHEM lower bound on kept pixels).
        min_kept = int(args.batch_size // num_gpus * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    else:
        raise NotImplementedError(
            "This repository now supports two datasets: cityscapes and camvid, %s is not included"
            % args.dataset)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # single card

    args.savedir = (args.dataset + '/' + args.savedir + args.model + 'bs' +
                    str(args.batch_size) + "_" + str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    start_epoch = 0

    # continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s Seed: %s" %
                     (str(total_paramters), GLOBAL_SEED))
        logger.write("\n%s\t\t%s\t%s\t%s" %
                     ('Epoch', 'Loss(Tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    # define optimization criteria (dataset was validated above)
    if args.dataset == 'camvid':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     args.lr, (0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=2e-4)

    elif args.dataset == 'cityscapes':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     args.lr, (0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-5)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []

    try:
        print('=====> beginning training')
        for epoch in range(start_epoch, args.max_epochs):
            # training
            lossTr, lr = train(args, trainLoader, model, criteria, optimizer,
                               epoch)
            lossTr_list.append(lossTr)

            # validation every 30 epochs and on the final epoch
            if epoch % 30 == 0 or epoch == (args.max_epochs - 1):
                epoches.append(epoch)
                mIOU_val, per_class_iu = val(args, valLoader, model)
                mIOU_val_list.append(mIOU_val)
                # record train information
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                             (epoch, lossTr, mIOU_val, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print(
                    "Epoch No.: %d\tTrain Loss = %.4f\t mIOU(val) = %.4f\t lr= %.6f\n"
                    % (epoch, lossTr, mIOU_val, lr))
            else:
                # record train information (empty mIOU column keeps lr aligned)
                logger.write("\n%d\t\t%.4f\t\t\t\t%.7f" % (epoch, lossTr, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print("Epoch No.: %d\tTrain Loss = %.4f\t lr= %.6f\n" %
                      (epoch, lossTr, lr))

            # save the model: the last 10 epochs plus every 100th epoch
            model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
            state = {"epoch": epoch + 1, "model": model.state_dict()}

            if epoch >= args.max_epochs - 10:
                torch.save(state, model_file_name)
            elif not epoch % 100:
                torch.save(state, model_file_name)

            # draw plots for visualization
            if epoch % 30 == 0 or epoch == (args.max_epochs - 1):
                # Plot the figures every 30 epochs (and on the last epoch)
                fig1, ax1 = plt.subplots(figsize=(11, 8))

                ax1.plot(range(start_epoch, epoch + 1), lossTr_list)
                ax1.set_title("Average training loss vs epochs")
                ax1.set_xlabel("Epochs")
                ax1.set_ylabel("Current loss")

                plt.savefig(args.savedir + "loss_vs_epochs.png")

                plt.clf()

                fig2, ax2 = plt.subplots(figsize=(11, 8))

                ax2.plot(epoches, mIOU_val_list, label="Val IoU")
                ax2.set_title("Average IoU vs epochs")
                ax2.set_xlabel("Epochs")
                ax2.set_ylabel("Current IoU")
                plt.legend(loc='lower right')

                plt.savefig(args.savedir + "iou_vs_epochs.png")

                plt.close('all')
    finally:
        # close the log even if training is interrupted
        logger.close()
Beispiel #11
0
def test_model(args):
    """
    main function for testing CGNet on CamVid.

    Loads (or computes and caches) the dataset statistics, builds the
    network, loads weights from ``args.resume`` when that file exists, and
    evaluates on every sample of ``args.test_data_list``. Prints the mIoU
    and the per-class IoU returned by ``test``.

    args:
       args: global arguments (parsed command-line namespace). ``args.seed``
             is overwritten in place.
    """
    print("=====> Check if the cached file exists ")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        dataCollect = CamVidTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std and class-weight information
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print('Error while pickling data. Please check.')
            exit(-1)
    else:
        print("%s exists" % (args.inform_data_file))
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # you trust. The `with` block makes sure the handle is closed.
        with open(args.inform_data_file, "rb") as f:
            datas = pickle.load(f)

    print(args)
    global network_type

    if args.cuda:
        print("=====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True

    M = args.M
    N = args.N
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architeture:  CGNet_M%sN%s" % (M, N))
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))
    print("data['classWeights']: ", datas['classWeights'])
    weight = torch.from_numpy(datas['classWeights'])
    print("=====> Dataset statistics")
    print("mean and std: ", datas['mean'], datas['std'])

    # define optimization criteria
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)
    if args.cuda:
        model = model.cuda()
        criteria = criteria.cuda()

    # load test set. Every sample must be evaluated: drop_last=True would
    # silently skip the final partial batch when batch_size > 1, and
    # shuffling has no purpose during evaluation.
    testLoader = data.DataLoader(CamVidValDataSet(args.data_dir,
                                                  args.test_data_list,
                                                  f_scale=1,
                                                  mean=datas['mean']),
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=False)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model'])
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    print("=====> beginning test")
    print("length of test set:", len(testLoader))
    mIOU_val, per_class_iu = test(args, testLoader, model, criteria)
    print(mIOU_val)
    print(per_class_iu)
Beispiel #12
0
def train(args, model, enc=False):
    """
    Train ``model`` on ADE20K and validate it after every epoch.

    Per-epoch results are appended to ``automated_log[_encoder].txt`` under
    ``../save/<args.savedir>``. A resumable checkpoint is written every
    epoch, the best model whenever the validation score improves, and a
    periodic snapshot every ``args.epochs_save`` epochs (when > 0).

    Args:
        args: parsed command-line namespace (reads datadir, base_size,
            crop_size, batch_size, num_workers, savedir, lr, resume,
            num_epochs, iouTrain, iouVal, NUM_CLASSES, cuda, steps_loss,
            epochs_save).
        model: network to train; called as ``model(inputs, only_encode=enc)``.
        enc: if True, train/evaluate only the encoder and use the encoder
            file names for logs and checkpoints.

    Returns:
        The trained model (convenience for encoder-decoder training).
    """
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    data_kwargs = {
        'dataset_root': args.datadir,
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'encode': enc
    }
    train_dataset = get_segmentation_dataset('ade20k',
                                             split='train',
                                             mode='train',
                                             **data_kwargs)
    val_dataset = get_segmentation_dataset('ade20k',
                                           split='val',
                                           mode='val',
                                           **data_kwargs)

    train_sampler = make_data_sampler(train_dataset,
                                      shuffle=True,
                                      distributed=False)
    train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                  args.batch_size)
    val_sampler = make_data_sampler(val_dataset,
                                    shuffle=False,
                                    distributed=False)
    val_batch_sampler = make_batch_data_sampler(val_sampler, args.batch_size)

    loader = data.DataLoader(dataset=train_dataset,
                             batch_sampler=train_batch_sampler,
                             num_workers=args.num_workers,
                             pin_memory=True)
    loader_val = data.DataLoader(dataset=val_dataset,
                                 batch_sampler=val_batch_sampler,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    criterion = CrossEntropyLoss2d()
    print(type(criterion))

    savedir = f'../save/{args.savedir}'

    if enc:
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"

    if not os.path.exists(automated_log_path):  # don't add the header twice
        with open(automated_log_path, "a") as myfile:
            myfile.write(
                "Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate"
            )

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))

    optimizer = Adam(model.parameters(),
                     args.lr, (0.9, 0.999),
                     eps=1e-08,
                     weight_decay=1e-4)

    start_epoch = 1
    # Start at -inf so the first epoch always counts as best. The
    # loss-based score (negated loss, always negative) could never beat 0.
    best_acc = float('-inf')
    if args.resume:
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(filenameCheckpoint)
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = float(checkpoint['best_acc'])
        print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch']))

    # Polynomial decay over 1-based epochs: epoch 1 uses the base lr.
    lambda1 = lambda epoch: pow(
        (1 - ((epoch - 1) / args.num_epochs)), 0.7)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer,
                                      lr_lambda=lambda1)  ## scheduler 2

    for epoch in range(start_epoch, args.num_epochs + 1):
        # Set this epoch's lr once, before any optimizer step. Then every
        # batch of the epoch, and the usedLr logged below, use the same value.
        scheduler.step(epoch)  ## scheduler 2
        print("----- TRAINING - EPOCH", epoch, "-----", " LR",
              optimizer.param_groups[0]['lr'], "-----")

        epoch_loss = []
        time_train = []

        doIouTrain = args.iouTrain
        doIouVal = args.iouVal

        if doIouTrain:
            iouEvalTrain = iouEval(args.NUM_CLASSES)

        usedLr = optimizer.param_groups[0]['lr']

        model.train()
        total_train_step = len(train_dataset) // args.batch_size
        total_val_step = len(val_dataset) // args.batch_size
        for step, (images, labels, _) in enumerate(loader):
            start_time = time.time()

            # skip the trailing partial batch
            imgs_batch = images.shape[0]
            if imgs_batch != args.batch_size:
                break

            # Always bind inputs/targets; only move them when CUDA is on.
            inputs, targets = images, labels
            if args.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            outputs = model(inputs, only_encode=enc)

            optimizer.zero_grad()
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())
            time_train.append(time.time() - start_time)

            if doIouTrain:
                targets = torch.unsqueeze(targets, 1)
                iouEvalTrain.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(
                    f'loss: {average:0.4} (epoch: {epoch}, step: {step}/{total_train_step})',
                    "// Remaining time: %.1f s" %
                    ((total_train_step - step) * sum(time_train) /
                     len(time_train)))

        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)

        iouTrain = 0
        if doIouTrain:
            iouTrain, iou_classes = iouEvalTrain.getIoU()
            print("EPOCH IoU on TRAIN set: ", iouTrain.item() * 100, "%")
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        if doIouVal:
            iouEvalVal = iouEval(args.NUM_CLASSES)

        for step, (images, labels, _) in enumerate(loader_val):
            start_time = time.time()

            imgs_batch = images.shape[0]
            if imgs_batch != args.batch_size:
                break
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            with torch.no_grad():
                inputs = images
                targets = labels
                outputs = model(inputs, only_encode=enc)
                loss = criterion(outputs, targets)
            epoch_loss_val.append(loss.item())
            time_val.append(time.time() - start_time)

            # Add batch to calculate TP, FP and FN for iou estimation
            if doIouVal:
                targets = torch.unsqueeze(targets, 1)
                iouEvalVal.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print(
                    f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step}/{total_val_step})',
                    "// Remaining time: %.1f s" %
                    ((total_val_step - step) * sum(time_val) / len(time_val)))

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)

        iouVal = 0
        if doIouVal:
            iouVal, iou_classes = iouEvalVal.getIoU()
            print("EPOCH IoU on VAL set: ", iouVal.item() * 100, "%")

        # remember best score and save checkpoint: use val IoU when it was
        # computed, otherwise the negated validation loss
        if iouVal == 0:
            current_acc = -average_epoch_loss_val
        else:
            current_acc = iouVal
        # float() accepts both the Python float (loss path) and a 0-dim
        # tensor (IoU path).
        current_acc = float(current_acc)
        print('best acc:', best_acc, ' current acc:', current_acc)
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, filenameCheckpoint, filenameBest)

        # SAVE MODEL AFTER EPOCH
        if enc:
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        # periodic snapshot every args.epochs_save epochs (keyed on the
        # epoch, not on a leftover loop step)
        if args.epochs_save > 0 and epoch % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if is_best:
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            if not enc:
                with open(savedir + "/best.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                                 (epoch, iouVal))
            else:
                with open(savedir + "/best_encoder.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                                 (epoch, iouVal))

        # Epoch, train loss, val loss, train IoU, val IoU, learning rate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" %
                         (epoch, average_epoch_loss_train,
                          average_epoch_loss_val, iouTrain, iouVal, usedLr))
    return model
Beispiel #13
0
                                  shuffle=True,
                                  collate_fn=collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    train_len = len(train_dataloader)
    valid_len = len(valid_dataloader)
    # network
    net = MODELS['StackHourglassNetMTL'](in_channels=4, task1_classes=1)
    if args.load_checkpoint is not None and args.mode != 'train':
        net.load_state_dict(
            torch.load(args.load_checkpoint, map_location='cpu'))
    net.to(device)
    optimizor = torch.optim.Adam(list(net.parameters()),
                                 lr=1e-4,
                                 weight_decay=1e-5)
    orien_loss = CrossEntropyLoss2d(size_average=True,
                                    ignore_index=255,
                                    reduce=True)
    road_loss = mIoULoss(n_classes=1, device=device)
    criterion = {
        'orien_loss': orien_loss,
        'road_loss': road_loss,
        'ce': nn.BCEWithLogitsLoss()
    }
    writer = SummaryWriter('./records/tensorboard/seg')

    for i in range(args.epochs):
        if args.mode != 'train':
            val(args, 0, net, valid_dataloader, 0, valid_len, writer)
            break
        train(args, i, net, train_dataloader, train_len, optimizor, criterion,
              writer, valid_dataloader, valid_len)
Beispiel #14
0
def main(args):
    """Train a segmentation generator adversarially against a discriminator.

    Builds the generator with ``get_model`` and warm-starts it from
    ``args.restore_from`` (a URL or a local path). Only parameters whose
    name and shape match are copied. It then builds a ``Discriminator``, the
    MIT scene-parsing train/val loaders and the two optimizers returned by
    ``create_optimizers``. It runs an initial evaluation, then alternates
    ``train`` with an ``evaluate`` call every ``args.eval_epoch`` epochs.
    Finally it prints the wall-clock running time.

    Args:
        args: parsed options. Fields read here are ``generatormodel``,
            ``num_classes``, ``restore_from``, ``num_gpus``, ``list_train``,
            ``list_val``, ``max_iters``, ``num_val``, ``batch_size``,
            ``workers``, ``imgSize``, ``start_epoch``, ``num_epoches`` and
            ``eval_epoch``. The whole object is also forwarded to the
            dataset, optimizer, train and evaluate helpers.
    """
    # Wall-clock timer for the running-time report printed at the end.
    start_time = timeit.default_timer()
    cudnn.enabled = True

    # create the generator network and fetch the weights to start from
    model = get_model(name=args.generatormodel, num_classes=args.num_classes)
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Only copy the params that exist in the current model with the same
    # shape; all other params keep their fresh initialisation.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()

    # load nets into gpu
    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model, device_ids=range(args.num_gpus))
    model.cuda()

    cudnn.benchmark = True  # acceleration

    # init D
    model_D = Discriminator(num_classes=args.num_classes)

    model_D.train()

    if args.num_gpus > 1:
        model_D = torch.nn.DataParallel(model_D,
                                        device_ids=range(args.num_gpus))
    model_D.cuda()

    train_dataset = MITSceneParsingDataset(args.list_train,
                                           args,
                                           max_iters=args.max_iters,
                                           is_train=1)

    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=int(args.workers),
                                  pin_memory=True,
                                  drop_last=True)

    val_dataset = MITSceneParsingDataset(args.list_val,
                                         args,
                                         max_sample=args.num_val,
                                         is_train=0)

    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        drop_last=True)

    # A single iterator is shared by every call to ``train`` below.
    # Presumably the dataset is sized via ``max_iters`` so one pass covers
    # the whole run. TODO confirm against ``train``.
    trainloader_iter = enumerate(trainloader)

    # losses
    bce_loss = torch.nn.BCEWithLogitsLoss()  # only 0, 1
    crit = CrossEntropyLoss2d(ignore_index=-1)

    # optimizers for the segmentation network and the discriminator
    optimizer, optimizer_D = create_optimizers(model, model_D, crit, args)

    # upsample network outputs to the (square) training image size
    interp = nn.Upsample(size=(args.imgSize, args.imgSize), mode='bilinear')

    # main loop
    history = {
        split: {
            'epoch': [],
            'loss_pred_outputs': [],
            'acc_pred_outputs': []
        }
        for split in ('train', 'val')
    }

    # initial eval
    evaluate(model, val_loader, interp, crit, history, 0, args)
    for epoch in range(args.start_epoch, args.num_epoches + 1):
        train(model, model_D, trainloader_iter, interp, optimizer, optimizer_D,
              crit, bce_loss, history, epoch, args)

        if epoch % args.eval_epoch == 0:
            evaluate(model, val_loader, interp, crit, history, epoch, args)

    end_time = timeit.default_timer()
    print(' running time(s): [{0:.4f} seconds]'.format(
        (end_time - start_time)))
def main():
    """Train SegNet on VOC with Nesterov SGD and a step learning-rate decay.

    Reads the module-level names ``num_classes``, ``train_path``,
    ``val_path``, ``ckpt_path`` and ``ignored_label``. Each epoch it calls
    ``train``. Every 20 epochs it divides the learning rate by 3 via
    ``adjust_lr``. It then calls ``validate`` with the ``best`` tracker
    ([best_val_loss, best_mean_iu, best_epoch]).
    """
    training_batch_size = 8
    validation_batch_size = 8
    epoch_num = 200
    iter_freq_print_training_log = 50
    lr = 1e-4

    net = SegNet(pretrained=True, num_classes=num_classes).cuda()
    curr_epoch = 0

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # joint (image + mask) augmentations, then image-only transforms
    train_simultaneous_transform = SimultaneousCompose([
        SimultaneousRandomHorizontallyFlip(),
        SimultaneousRandomScale((0.9, 1.1)),
        SimultaneousRandomCrop((300, 500))
    ])
    train_transform = transforms.Compose([
        RandomGaussianBlur(),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    val_simultaneous_transform = SimultaneousCompose(
        [SimultaneousScale((300, 500))])
    val_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])
    # inverse of the normalisation, used by ``validate``
    restore = transforms.Compose(
        [DeNormalize(*mean_std),
         transforms.ToPILImage()])

    train_set = VOC(train_path,
                    simultaneous_transform=train_simultaneous_transform,
                    transform=train_transform,
                    target_transform=MaskToTensor())
    train_loader = DataLoader(train_set,
                              batch_size=training_batch_size,
                              num_workers=8,
                              shuffle=True)
    val_set = VOC(val_path,
                  simultaneous_transform=val_simultaneous_transform,
                  transform=val_transform,
                  target_transform=MaskToTensor())
    val_loader = DataLoader(val_set,
                            batch_size=validation_batch_size,
                            num_workers=8)

    criterion = CrossEntropyLoss2d(ignored_label=ignored_label)

    # Bias parameters use SGD's default weight decay (0); every other
    # parameter is regularised with weight_decay=5e-4.
    bias_params, other_params = [], []
    for name, param in net.named_parameters():
        if name.endswith('bias'):
            bias_params.append(param)
        else:
            other_params.append(param)
    optimizer = optim.SGD([
        {'params': bias_params},
        {'params': other_params, 'weight_decay': 5e-4},
    ],
                          lr=lr,
                          momentum=0.9,
                          nesterov=True)

    # Create the checkpoint dir (and any missing parents); no-op if present.
    os.makedirs(ckpt_path, exist_ok=True)

    best = [1e9, -1, -1]  # [best_val_loss, best_mean_iu, best_epoch]

    for epoch in range(curr_epoch, epoch_num):
        train(train_loader, net, criterion, optimizer, epoch,
              iter_freq_print_training_log)
        if (epoch + 1) % 20 == 0:
            lr /= 3
            adjust_lr(optimizer, lr)
        validate(epoch, val_loader, net, criterion, restore, best)
Beispiel #16
0
    def __init__(self, args):
        """Set up data loaders, model, loss, optimizer, scheduler and logging.

        Args:
            args: parsed options. Fields read here are ``mode``, ``epochs``,
                ``dataset`` ('Potsdam', 'UDD5' or 'UDD6'), ``data_path``,
                ``train_crop_size``, ``eval_crop_size``, ``stride``,
                ``train_batch_size``, ``model`` ('FCN', 'DeepLabV3+', 'GCN',
                'UNet', 'ENet' or 'D-LinkNet'), ``loss`` ('CE', 'LS', 'F' or
                'CE+D'), ``lr``, ``schedule_mode`` ('step', 'miou', 'acc'
                or 'poly'), ``cuda``, ``resume``, ``finetune`` and
                ``init_eval``.

        Raises:
            NotImplementedError: for an unsupported dataset, model, loss or
                schedule mode.
            AssertionError: if both ``resume`` and ``finetune`` are given.
        """
        self.args = args
        self.mode = args.mode
        self.epochs = args.epochs
        self.dataset = args.dataset
        self.data_path = args.data_path
        self.train_crop_size = args.train_crop_size
        self.eval_crop_size = args.eval_crop_size
        self.stride = args.stride
        self.batch_size = args.train_batch_size
        self.train_data = AerialDataset(crop_size=self.train_crop_size,
                                        dataset=self.dataset,
                                        data_path=self.data_path,
                                        mode='train')
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=2)
        self.eval_data = AerialDataset(dataset=self.dataset,
                                       data_path=self.data_path,
                                       mode='val')
        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=2)

        # Class count per dataset. epoch_repeat comes from get_test_times
        # with what look like image width/height and the crop size;
        # presumably the number of crops per image. TODO confirm.
        if self.dataset == 'Potsdam':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(6000, 6000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD5':
            self.num_of_class = 5
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD6':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        else:
            raise NotImplementedError

        if args.model == 'FCN':
            self.model = models.FCN8(num_classes=self.num_of_class)
        elif args.model == 'DeepLabV3+':
            self.model = models.DeepLab(num_classes=self.num_of_class,
                                        backbone='resnet')
        elif args.model == 'GCN':
            self.model = models.GCN(num_classes=self.num_of_class)
        elif args.model == 'UNet':
            self.model = models.UNet(num_classes=self.num_of_class)
        elif args.model == 'ENet':
            self.model = models.ENet(num_classes=self.num_of_class)
        elif args.model == 'D-LinkNet':
            self.model = models.DinkNet34(num_classes=self.num_of_class)
        else:
            raise NotImplementedError

        if args.loss == 'CE':
            self.criterion = CrossEntropyLoss2d()
        elif args.loss == 'LS':
            self.criterion = LovaszSoftmax()
        elif args.loss == 'F':
            self.criterion = FocalLoss()
        elif args.loss == 'CE+D':
            self.criterion = CE_DiceLoss()
        else:
            raise NotImplementedError

        self.schedule_mode = args.schedule_mode
        self.optimizer = opt.AdamW(self.model.parameters(), lr=args.lr)
        if self.schedule_mode == 'step':
            self.scheduler = opt.lr_scheduler.StepLR(self.optimizer,
                                                     step_size=30,
                                                     gamma=0.1)
        elif self.schedule_mode == 'miou' or self.schedule_mode == 'acc':
            # mode='max': the monitored metric (mIoU / accuracy) should rise
            self.scheduler = opt.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                mode='max',
                                                                patience=10,
                                                                factor=0.1)
        elif self.schedule_mode == 'poly':
            iters_per_epoch = len(self.train_loader)
            self.scheduler = Poly(self.optimizer,
                                  num_epochs=args.epochs,
                                  iters_per_epoch=iters_per_epoch)
        else:
            raise NotImplementedError

        self.evaluator = Evaluator(self.num_of_class)

        # Capture the architecture name before wrapping. After
        # nn.DataParallel, self.model.__class__.__name__ is always
        # "DataParallel", which would make every TensorBoard run name
        # identical.
        model_name = self.model.__class__.__name__
        self.model = nn.DataParallel(self.model)

        self.cuda = args.cuda
        if self.cuda is True:
            self.model = self.model.cuda()

        self.resume = args.resume
        self.finetune = args.finetune
        assert not (self.resume is not None and self.finetune is not None)

        # With CUDA keep tensors where they were saved (torch.load default);
        # otherwise remap everything onto the CPU.
        map_location = None if self.cuda else 'cpu'
        if self.resume is not None:
            print("Loading existing model...")
            checkpoint = torch.load(args.resume, map_location=map_location)
            self.model.load_state_dict(checkpoint['parameters'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            # start from the epoch after the one that was checkpointed
            self.start_epoch = checkpoint['epoch'] + 1
        elif self.finetune is not None:
            print("Loading existing model...")
            checkpoint = torch.load(args.finetune, map_location=map_location)
            # Only weights are restored; optimizer/scheduler start fresh.
            self.model.load_state_dict(checkpoint['parameters'])
            self.start_epoch = checkpoint['epoch'] + 1
        else:
            self.start_epoch = 1
        if self.mode == 'train':
            self.writer = SummaryWriter(comment='-' + self.dataset + '_' +
                                        model_name + '_' +
                                        args.loss)
        self.init_eval = args.init_eval