Ejemplo n.º 1
0
def CreateDataset(opt):
    """Build the Cityscapes dataset rooted at ``opt.dataroot``.

    Each sample goes through a fixed pipeline:
    1. resize to 256x256;
    2. random horizontal flip;
    3. tensor conversion;
    4. per-channel normalisation with mean 0.5 / std 0.5.

    The pipeline is passed to ``data.cityscapes.Cityscapes`` via its
    ``transforms`` keyword, together with ``resolution=256``.

    Args:
        opt: options object; only ``opt.dataroot`` is read.

    Returns:
        The constructed ``Cityscapes`` dataset (its ``name()`` is printed).
    """
    from data.cityscapes import Cityscapes

    steps = [
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ]
    pipeline = transforms.Compose(steps)

    cityscapes = Cityscapes(
        root=opt.dataroot,
        transforms=pipeline,
        resolution=256,
    )
    print("dataset [%s] was created" % cityscapes.name())
    return cityscapes
Ejemplo n.º 2
0
    def setup(self, stage):
        """Create the Cityscapes train/val datasets and store them on ``self``.

        Both splits use fine annotations with semantic targets and no image or
        target transforms. ``stage`` is accepted but not used.
        """
        # Everything except the split name is shared between the two datasets.
        shared = dict(
            root="../../maskgan/data/cityscapes/",
            mode="fine",
            target_type="semantic",
            transform=None,
            target_transform=None,
        )
        # Build both datasets before touching self, as the original did.
        train_ds, val_ds = (Cityscapes(split=part, **shared) for part in ("train", "val"))

        self.train_dataset = train_ds
        self.val_dataset = val_ds
Ejemplo n.º 3
0
def predict():
    """Segment one Cityscapes image with the saved ENet checkpoint and show an overlay.

    The dataset object is built only to obtain its colour encoding, which fixes
    the number of classes. The model weights are then restored via
    ``utils.load_checkpoint(model, optimizer, 'save', 'ENet_cityscapes_mine.pth')``.
    The per-pixel arg-max prediction is colourised with ``label_to_color_image``
    and alpha-blended over the input image in an OpenCV window.
    """
    image_transform = transforms.Compose(
        [transforms.Resize(target_size),
         transforms.ToTensor()])

    label_transform = transforms.Compose(
        [transforms.Resize(target_size),
         ext_transforms.PILToLongTensor()])

    # Instantiated only to read the class -> colour encoding.
    train_set = Cityscapes(data_dir,
                           mode='test',
                           transform=image_transform,
                           label_transform=label_transform)

    class_encoding = train_set.color_encoding
    num_classes = len(class_encoding)
    model = ENet(num_classes).to(device)

    # load_checkpoint needs an optimizer argument even though it is unused here.
    optimizer = optim.Adam(model.parameters())
    model = utils.load_checkpoint(model, optimizer, 'save',
                                  'ENet_cityscapes_mine.pth')[0]
    # Switch dropout / batch-norm layers (if any) to inference behaviour;
    # otherwise batch-norm would use the statistics of this single image.
    model.eval()

    # Force 3 channels so the tensor matches the network input and the overlay below.
    image = Image.open('images/mainz_000000_008001_leftImg8bit.png').convert('RGB')
    images = image_transform(image).to(device).unsqueeze(0)
    image = np.array(image)

    # Make predictions without building an autograd graph.
    with torch.no_grad():
        scores = model(images)
    _, predictions = torch.max(scores, 1)
    # Shift so the predicted classes map to 0~18.
    # NOTE(review): class 0 becomes -1 here; confirm label_to_color_image handles it.
    prediction = predictions.cpu().numpy()[0] - 1

    mask_color = np.asarray(label_to_color_image(prediction, 'cityscapes'),
                            dtype=np.uint8)
    # Nearest-neighbour keeps the palette exact. Bilinear (the cv2 default)
    # would invent blended colours along class borders.
    mask_color = cv2.resize(mask_color, (image.shape[1], image.shape[0]),
                            interpolation=cv2.INTER_NEAREST)
    print(image.shape)
    print(mask_color.shape)
    res = cv2.addWeighted(image, 0.3, mask_color, 0.7, 0.6)
    # PIL delivers RGB but cv2.imshow interprets arrays as BGR.
    # NOTE(review): assumes label_to_color_image also returns RGB - confirm.
    cv2.imshow('combined', cv2.cvtColor(res, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    cv2.destroyAllWindows()
Ejemplo n.º 4
0
# When only evaluating, the training split reuses the validation pipeline;
# otherwise it gets flip + random square crop/scale augmentation.
trans_train = trans_val if evaluating else Compose([
    Open(),
    RandomFlip(),
    RandomSquareCropAndScale(random_crop_size, ignore_id=num_classes, mean=mean_rgb),
    SetTargetSize(target_size=target_size_crops, target_size_feats=target_size_crops_feats),
    Tensor(),
])

dataset_train, dataset_val = (
    Cityscapes(root, transforms=trans_train, subset='train'),
    Cityscapes(root, transforms=trans_val, subset='val'),
)

# ResNet-18 backbone wrapped in the semantic segmentation head.
resnet = resnet18(pretrained=True, efficient=False, mean=mean, std=std, scale=scale)
model = SemsegModel(resnet, num_classes)

if pruning:
    # When pruning, weights are initialised from the single-scale checkpoint.
    model.load_state_dict(torch.load('weights/rn18_single_scale/model_best.pt'))

if evaluating:
    model.load_state_dict(
Ejemplo n.º 5
0
)

# Evaluation reuses the validation pipeline for the training split. Training
# adds flip / crop-scale augmentation plus a distance-transform label target.
trans_train = trans_val if evaluating else Compose([
    Open(copy_labels=False),
    RandomFlip(),
    RandomSquareCropAndScale(random_crop_size, ignore_id=ignore_id, mean=mean_rgb),
    SetTargetSize(target_size=target_size_crops, target_size_feats=target_size_crops_feats),
    LabelDistanceTransform(
        num_classes=num_classes,
        reduce=True,
        bins=dist_trans_bins,
        alphas=dist_trans_alphas,
        ignore_id=ignore_id,
    ),
    Tensor(),
])

dataset_train, dataset_val = (
    Cityscapes(root, transforms=trans_train, subset='train', labels_dir='labels'),
    Cityscapes(root, transforms=trans_val, subset='val', labels_dir='labels'),
)

# Copy the class-level metadata onto each dataset instance.
for dset in (dataset_train, dataset_val):
    for atter in ('class_info', 'color_info'):
        setattr(dset, atter, getattr(Cityscapes, atter))

# ResNet-34 backbone with output stride 8 and a 1x1 (k=1) biased classifier.
resnet = resnet34(
    pretrained=True,
    k_up=3,
    scale=scale,
    mean=mean,
    std=std,
    output_stride=8,
    efficient=False,
)
model = SemsegModel(resnet, num_classes, k=1, bias=True)

if pruning:
    # Non-strict load: keys missing from / extra in the checkpoint are tolerated.
    model.load_state_dict(torch.load("weights/76-66_resnet34x8/stored/model_best.pt"), strict=False)

if evaluating:
    model.load_state_dict(torch.load("weights/76-66_resnet34x8/stored/model_best.pt"), strict=False)
else:
Ejemplo n.º 6
0
# Single-scale (alpha = 1) full-resolution Cityscapes evaluation settings.
alphas = [1.]
ts = (2048, 1024)
target_size = ts
# Feature maps are a quarter of the input resolution along each side.
target_size_feats = tuple(side // 4 for side in ts)
scale = 255
mean = Cityscapes.mean
std = Cityscapes.std
nw = 1  # DataLoader worker processes

# Train and val share one pipeline: open, remap label ids, build the image
# pyramid, set target sizes, normalise, convert to tensors.
trans_train = trans_val = Compose([
    Open(),
    RemapLabels(Cityscapes.map_to_id, Cityscapes.num_classes),
    Pyramid(alphas=alphas),
    SetTargetSize(target_size=target_size, target_size_feats=target_size_feats),
    Normalize(scale, mean, std),
    Tensor(),
])

if __name__ == "__main__":
    root = Path('datasets/Cityscapes')
    dataset_val = Cityscapes(root, transforms=trans_val, subset='val')
    loader_val = DataLoader(
        dataset_val,
        batch_size=1,
        collate_fn=custom_collate,
        num_workers=nw,
    )
    eval_loaders = [(loader_val, 'val')]
    class_info = dataset_val.class_info

    # Report mean IoU for every evaluation loader.
    for loader, name in eval_loaders:
        iou, per_class_iou = run_acc_test(loader, class_info)
        print(f'{name}: {iou:.2f}')
Ejemplo n.º 7
0
def main():
    """Fine-tune a segmentation model on GTA5 with LRs chosen by a pretrained RL controller.

    A ``Policy`` is initialised from ``./pretrained/policy_vgg16_segmentation.pth``.
    Every ``_window_size`` optimizee steps it picks, for each SGD parameter group
    in ``sgd_in_names``, a multiplier from ``action_space`` that scales
    ``args.lr``. When ``RANDOM`` is set, multipliers are drawn with ``choice``
    instead. After each epoch the model is validated on Cityscapes (mIoU),
    TensorBoard scalars are written and a checkpoint is saved under ``args.name``.
    """
    args = get_args()
    pid = os.getpid()
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    best_mIoU = 0

    #### preparation ###########################################
    from config_seg import config as data_setting
    data_setting.batch_size = args.batch_size
    train_loader = get_train_loader(data_setting, GTA5, test=False)
    train_loader_iter = iter(train_loader)
    current_optimizee_step, prev_optimizee_step = 0, 0

    model_old = None
    if args.lwf:
        # Fixed ImageNet VGG16 copy (eval mode) handed to train_step as model_old
        # for the life-long-learning (LwF) term.
        model_old = vgg16(pretrained=True)
        model_old.eval()
        model_old.to(device)
    ############################################################

    ### Agent Settings ########################################
    RANDOM = False  # False | True | 'init'
    action_space = np.arange(0, 1.1, 0.1)  # LR multipliers 0.0, 0.1, ..., 1.0
    obs_avg = True
    _window_size = 1
    window_size = 1 if obs_avg else _window_size
    window_shrink_size = 20  # larger: controller will be updated more frequently w.r.t. optimizee_step
    sgd_in_names = ["conv1", "conv2", "conv3", "conv4", "conv5", "FC", "fc_new"]
    coord_size = len(sgd_in_names)
    ob_name_lstm = ["loss", "loss_kl", "step", "fc_mean", "fc_std"]
    ob_name_scalar = []
    obs_shape = (len(ob_name_lstm) * window_size + len(ob_name_scalar) + coord_size, )
    _hidden_size = 20
    hidden_size = _hidden_size * len(ob_name_lstm)
    actor_critic = Policy(coord_size, input_size=(len(ob_name_lstm), len(ob_name_scalar)), action_space=len(action_space), hidden_size=_hidden_size, window_size=window_size)
    actor_critic.to(device)
    actor_critic.eval()

    # Overlay the pretrained controller weights (loaded onto CPU) on the fresh policy.
    partial = torch.load("./pretrained/policy_vgg16_segmentation.pth", map_location=lambda storage, loc: storage)
    state = actor_critic.state_dict()
    state.update(partial)
    actor_critic.load_state_dict(state)

    # NOTE(review): `agent` is built but never used below; the controller only acts.
    if args.algo == 'reinforce':
        agent = algo.REINFORCE(
            actor_critic,
            args.entropy_coef,
            lr=args.lr_meta,
            eps=args.eps,
            alpha=args.alpha,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr_meta,
            eps=args.eps,
            alpha=args.alpha,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr_meta,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True)
    ################################################################

    _min_iter = 20
    # reset optimizee
    model, optimizer, current_optimizee_step, prev_optimizee_step = prepare_optimizee(args, sgd_in_names, obs_shape, hidden_size, actor_critic, current_optimizee_step, prev_optimizee_step)

    ##### Logging ###########################
    if RANDOM:
        args.name = "Random_GTA5_Min%diter.Step%d.Window%d_batch%d_Epoch%d_LR%.1e.warmpoly_lwf.%d"%\
            (_min_iter, args.num_steps, window_shrink_size, args.batch_size, args.epochs, args.lr, args.lwf)
    else:
        args.name = "metatrain_GTA5_%s.SGD.Gamma%.1f.LRmeta.%.1e.Hidden%d.Loss.avg.exp.Earlystop.%d.Min%diter.Step%d.Window%d_batch%d_Epoch%d_LR%.1e.warmpoly_lwf.%d"%\
            (args.algo, args.gamma, args.lr_meta, _hidden_size, args.early_stop, _min_iter, args.num_steps, window_shrink_size, args.batch_size, args.epochs, args.lr, args.lwf)
        if args.resume:
            args.name += "_resumed"

    # Log outputs (file + console) under runs/<name>/
    directory = "runs/%s/"%(args.name)
    os.makedirs(directory, exist_ok=True)
    filename = directory + 'train.log'
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    rootLogger = logging.getLogger()
    logFormatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s]  %(message)s")
    fileHandler = logging.FileHandler(filename)
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    rootLogger.addHandler(consoleHandler)
    rootLogger.setLevel(logging.INFO)

    writer = SummaryWriter(directory)
    ###########################################

    threds = 1
    evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), args.num_class, np.array([0.485, 0.456, 0.406]),
                    np.array([0.229, 0.224, 0.225]), model, [1, ], False, devices=0, config=data_setting, threds=threds,
                    verbose=False, save_path=None, show_image=False)

    def controller_window(global_step):
        # Optimizee steps to run before the controller acts again. The window
        # grows with the *total* step count divided by window_shrink_size.
        # The parentheses at the call sites matter: previously only
        # prev_optimizee_step was divided.
        return max(_min_iter, global_step // window_shrink_size)

    def log_train_stats(step, loss, loss_kl, fc_mean, fc_std):
        # TensorBoard scalars reported after every train_step window.
        writer.add_scalar("loss/ce", loss, step)
        writer.add_scalar("loss/kl", loss_kl, step)
        writer.add_scalar("loss/total", loss + loss_kl, step)
        writer.add_scalar("fc/mean", fc_mean, step)
        writer.add_scalar("fc/std", fc_std, step)

    epoch_size = len(train_loader)
    total_steps = epoch_size*args.epochs
    bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
    pbar = tqdm(range(int(epoch_size*args.epochs)), file=sys.stdout, bar_format=bar_format, ncols=100)

    # Warm-up window so the controller has a first observation to act on.
    step = current_optimizee_step + prev_optimizee_step
    _window_size = controller_window(step)
    train_loader_iter, obs, loss, loss_kl, fc_mean, fc_std = train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, args.lr, pbar, step, total_steps, model_old=model_old)
    log_train_stats(step, loss, loss_kl, fc_mean, fc_std)
    current_optimizee_step += _window_size
    prev_obs = obs.unsqueeze(0)
    # Follow `device` (the policy lives there) instead of assuming CUDA.
    prev_hidden = torch.zeros(actor_critic.net.num_recurrent_layers, 1, hidden_size).to(device)
    for epoch in range(args.epochs):
        print("\n===== Epoch %d / %d ====="%(epoch+1, args.epochs))
        print("============= " + args.name + " ================")
        print("============= PID: " + str(pid) + " ================")
        while current_optimizee_step < epoch_size:
            step = current_optimizee_step + prev_optimizee_step
            # Sample actions (one LR multiplier per parameter group)
            with torch.no_grad():
                if not RANDOM:
                    _value, action, _log_prob, recurrent_hidden_states, distribution = actor_critic.act(prev_obs, prev_hidden, deterministic=False)
                    action = action.squeeze(0)
                    for idx, act in enumerate(action):
                        writer.add_scalar("action/%s"%(sgd_in_names[idx]), act, step)
                        writer.add_scalar("entropy/%s"%(sgd_in_names[idx]), distribution.distributions[idx].entropy(), step)
                        optimizer.param_groups[idx]['lr'] = float(action_space[act]) * args.lr
                        writer.add_scalar("LR/%s"%(sgd_in_names[idx]), optimizer.param_groups[idx]['lr'], step)
                else:
                    if RANDOM is True or RANDOM == 'init':
                        for idx in range(coord_size):
                            optimizer.param_groups[idx]['lr'] = float(choice(action_space)) * args.lr
                    if RANDOM == 'init':
                        RANDOM = 'done'
                    for idx in range(coord_size):
                        writer.add_scalar("LR/%s"%sgd_in_names[idx], optimizer.param_groups[idx]['lr'], step)

            # Observe reward and next obs; never run past the end of the epoch.
            _window_size = min(controller_window(step), epoch_size - current_optimizee_step)
            train_loader_iter, obs, loss, loss_kl, fc_mean, fc_std = train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, args.lr, pbar, step, total_steps, model_old=model_old)
            log_train_stats(step, loss, loss_kl, fc_mean, fc_std)
            current_optimizee_step += _window_size
            prev_obs = obs.unsqueeze(0)
            if not RANDOM:
                prev_hidden = recurrent_hidden_states
        prev_optimizee_step += current_optimizee_step
        current_optimizee_step = 0

        # evaluate on validation set
        torch.cuda.empty_cache()
        mIoU = validate(evaluator, model)
        writer.add_scalar("mIoU", mIoU, epoch)

        # remember best mIoU and save checkpoint
        is_best = mIoU > best_mIoU
        best_mIoU = max(mIoU, best_mIoU)
        save_checkpoint(args.name, {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mIoU': best_mIoU,
        }, is_best)

        logging.info('Best accuracy: {mIoU:.3f}'.format(mIoU=best_mIoU))
Ejemplo n.º 8
0
def main():
    """Fine-tune an FCN-VGG16 on GTA5 and validate it on Cityscapes after each epoch.

    Logging:
        CLI options come from ``parser``. Logs go to ``runs/<name>/train.log``
        and the console, and TensorBoard scalars go to the same directory.

    Optimisation:
        The VGG16 stages (conv1..conv5, fc6/fc7) train at
        ``args.factor * args.lr``. The FCN scoring/upsampling head trains at
        ``args.lr``.

    Checkpoints and LwF:
        Training optionally resumes from ``args.resume``. When ``args.lwf > 0``,
        a fixed ImageNet VGG16 is passed to ``train`` as ``model_old``.

    Evaluate-only mode:
        With ``args.evaluate`` the model is validated once and the function
        returns without training.
    """
    global args, best_mIoU
    args = parser.parse_args()
    pid = os.getpid()

    # Log outputs
    args.name = "GTA5_Vgg16_batch%d_512x512_Poly_LR%.1e_1to%.1f_all_lwf.%d_epoch%d" % (
        args.batch_size, args.lr, args.factor, args.lwf, args.epochs)
    if args.resume:
        args.name += "_resumed"
    directory = "runs/%s/" % (args.name)
    os.makedirs(directory, exist_ok=True)
    filename = directory + 'train.log'
    # Replace previously installed root handlers so records are not duplicated.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    rootLogger = logging.getLogger()
    logFormatter = logging.Formatter(
        "%(asctime)s [%(levelname)-5.5s]  %(message)s")
    fileHandler = logging.FileHandler(filename)
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    rootLogger.addHandler(consoleHandler)
    rootLogger.setLevel(logging.INFO)

    writer = SummaryWriter(directory)

    from config_seg import config as data_setting
    data_setting.batch_size = args.batch_size
    train_loader = get_train_loader(data_setting, GTA5, test=False)

    ##### Vgg16 #####
    vgg = vgg16(pretrained=True)
    model = FCN_Vgg(n_class=args.num_class)
    model.copy_params_from_vgg16(vgg)
    ###################
    threds = 1
    evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                             args.num_class,
                             np.array([0.485, 0.456, 0.406]),
                             np.array([0.229, 0.224, 0.225]),
                             model, [
                                 1,
                             ],
                             False,
                             devices=args.gpus,
                             config=data_setting,
                             threds=threds,
                             verbose=False,
                             save_path=None,
                             show_image=False)

    # Setup optimizer: one SGD group per VGG16 stage (scaled LR), then the FCN head (base LR).
    backbone_groups = [
        ["conv1_1", "conv1_2"],
        ["conv2_1", "conv2_2"],
        ["conv3_1", "conv3_2", "conv3_3"],
        ["conv4_1", "conv4_2", "conv4_3"],
        ["conv5_1", "conv5_2", "conv5_3"],
        ["fc6", "fc7"],
    ]
    head_layers = [
        "score_fr", "score_pool3", "score_pool4", "upscore2",
        "upscore8", "upscore_pool4"
    ]
    sgd_in = [{'params': get_params(model, layers), 'lr': args.factor * args.lr}
              for layers in backbone_groups]
    sgd_in.append({'params': get_params(model, head_layers), 'lr': args.lr})
    base_lrs = [group['lr'] for group in sgd_in]
    optimizer = torch.optim.SGD(sgd_in,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Load onto CPU first; the model is moved to the GPU below.
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            args.start_epoch = checkpoint['epoch']
            best_mIoU = checkpoint['best_mIoU']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    model = model.cuda()
    model_old = None
    if args.lwf > 0:
        # Fixed (no-grad, eval-mode) copy for life-long learning
        model_old = vgg16(pretrained=True)
        for param in model_old.parameters():
            param.requires_grad = False
        model_old.eval()
        model_old.cuda()

    if args.evaluate:
        mIoU = validate(evaluator, model)
        print(mIoU)
        # Evaluation-only run: stop here instead of training (and writing
        # new checkpoints).
        return

    # Main training loop
    iter_max = args.epochs * math.ceil(len(train_loader) / args.iter_size)
    iter_stat = IterNums(iter_max)
    for epoch in range(args.start_epoch, args.epochs):
        logging.info("============= " + args.name + " ================")
        logging.info("============= PID: " + str(pid) + " ================")
        logging.info("Epoch: %d" % (epoch + 1))
        # train for one epoch
        train(args,
              train_loader,
              model,
              optimizer,
              base_lrs,
              iter_stat,
              epoch,
              writer,
              model_old=model_old,
              adjust_lr=epoch < args.epochs)
        # evaluate on validation set
        torch.cuda.empty_cache()
        mIoU = validate(evaluator, model)
        writer.add_scalar("mIoU", mIoU, epoch)
        # remember best mIoU and save checkpoint
        is_best = mIoU > best_mIoU
        best_mIoU = max(mIoU, best_mIoU)
        save_checkpoint(
            directory, {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mIoU': best_mIoU,
            }, is_best)

    logging.info('Best accuracy: {mIoU:.3f}'.format(mIoU=best_mIoU))
Ejemplo n.º 9
0
def main():
    """Train a CSG-wrapped DeepLab (ResNet-50/101 layout) on GTA5 and validate on Cityscapes.

    Outputs:
        CLI options come from ``parser``. Results go to a run directory derived
        from the hyper-parameters (or an ``-evaluate`` directory when
        ``args.evaluate`` is set).

    Optimisation:
        The backbone groups (conv1, bn1, layer1-4) train at
        ``args.factor * args.lr``. ``fc_new`` trains at ``args.lr``.

    Evaluate-only mode:
        With ``args.evaluate`` the process exits after one validation pass.

    Raises:
        ValueError: if ``args.switch_model`` is not ``deeplab50``/``deeplab101``,
            or the number of ``args.chunks`` is neither 1 nor the number of
            ``args.csg_stages``.
    """
    global args, best_mIoU
    import sys

    PID = os.getpid()
    args = parser.parse_args()
    prepare_seed(args.rand_seed)
    device = torch.device("cuda:" + str(args.gpus))

    if args.timestamp == 'none':
        # %S = seconds. The former %s is a glibc-only "seconds since epoch"
        # directive, non-portable and wrong for a gmtime() struct.
        args.timestamp = time.strftime('%h-%d-%C_%H-%M-%S', time.gmtime())

    switch_model = args.switch_model
    # Blocks per ResNet stage for each supported backbone.
    layers_by_model = {"deeplab50": [3, 4, 6, 3], "deeplab101": [3, 4, 23, 3]}
    if switch_model not in layers_by_model:
        raise ValueError("switch_model must be one of %s, got %r"
                         % (sorted(layers_by_model), switch_model))
    layers = layers_by_model[switch_model]

    # Log outputs
    resume_suffix = "%s/%s" % ('/' + args.resume if args.resume != 'none' else '', args.timestamp)
    if args.evaluate:
        args.save_dir = args.save_dir + "/GTA5-%s-evaluate" % switch_model + resume_suffix
    else:
        args.save_dir = args.save_dir + \
            "/GTA5_512x512-{model}-LWF.stg{csg_stages}.w{csg_weight}-APool.{apool}-Aug.{augment}-chunk{chunks}-mlp{mlp}.K{csg_k}-LR{lr}.bone{factor}-epoch{epochs}-batch{batch_size}-seed{seed}".format(
                    model=switch_model,
                    csg_stages=args.csg_stages,
                    mlp=args.mlp,
                    csg_weight=args.csg,
                    apool=args.apool,
                    augment=args.augment,
                    chunks=args.chunks,
                    csg_k=args.csg_k,
                    lr="%.2E" % args.lr,
                    factor="%.1f" % args.factor,
                    epochs=args.epochs,
                    batch_size=args.batch_size,
                    seed=args.rand_seed
                    ) + resume_suffix
    logger = prepare_logger(args)

    from config_seg import config as data_setting
    data_setting.batch_size = args.batch_size
    train_loader = get_train_loader(data_setting,
                                    GTA5,
                                    test=False,
                                    augment=args.augment)

    # "1.2.3"-style strings -> lists of ints (empty string -> empty list).
    args.stages = [int(stage) for stage in args.csg_stages.split('.')] if args.csg_stages else []
    chunks = [int(chunk) for chunk in args.chunks.split('.')] if args.chunks else []
    if not (len(chunks) == 1 or len(chunks) == len(args.stages)):
        raise ValueError("expected 1 chunk value or one per CSG stage (%d), got %d"
                         % (len(args.stages), len(chunks)))
    if len(chunks) < len(args.stages):
        # A single chunk value is broadcast to every stage.
        chunks = [chunks[0]] * len(args.stages)

    model = csg_builder.CSG(deeplab,
                            get_head=None,
                            K=args.csg_k,
                            stages=args.stages,
                            chunks=chunks,
                            task='new-seg',
                            apool=args.apool,
                            mlp=args.mlp,
                            base_encoder_kwargs={
                                'num_seg_classes': args.num_classes,
                                'layers': layers
                            })

    threds = 3
    evaluator = SegEvaluator(
        Cityscapes(data_setting, 'val', None),
        args.num_classes,
        np.array([0.485, 0.456, 0.406]),
        np.array([0.229, 0.224, 0.225]),
        model.encoder_q, [
            1,
        ],
        False,
        devices=args.gpus,
        config=data_setting,
        threds=threds,
        verbose=False,
        save_path=None,
        show_image=False
    )  # just calculate mIoU, no prediction file is generated
    # verbose=False, save_path="./prediction_files", show_image=True, show_prediction=True)  # generate prediction files

    # Setup optimizer: backbone groups at factor * lr, new classifier at lr.
    factor = args.factor
    backbone_parts = ["conv1", "bn1", "layer1", "layer2", "layer3", "layer4"]
    sgd_in = [{'params': get_params(model.encoder_q, [part]), 'lr': factor * args.lr}
              for part in backbone_parts]
    sgd_in.append({'params': get_params(model.encoder_q, ["fc_new"]), 'lr': args.lr})
    base_lrs = [group['lr'] for group in sgd_in]
    optimizer = SGD(sgd_in,
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

    # Optionally resume from a checkpoint
    if args.resume != 'none':
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            args.start_epoch = checkpoint['epoch']
            best_mIoU = checkpoint['best_mIoU']
            msg = model.load_state_dict(checkpoint['state_dict'])
            print("resume weights: ", msg)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    model = model.to(device)

    if args.evaluate:
        mIoU = validate(evaluator, model, -1)
        print(mIoU)
        sys.exit(0)

    # Main training loop
    iter_max = args.epochs * len(train_loader)
    iter_stat = IterNums(iter_max)
    for epoch in range(args.start_epoch, args.epochs):
        print("<< ============== JOB (PID = %d) %s ============== >>" %
              (PID, args.save_dir))
        logger.log("Epoch: %d" % (epoch + 1))
        # train for one epoch
        train(args,
              train_loader,
              model,
              optimizer,
              base_lrs,
              iter_stat,
              epoch,
              logger,
              device,
              adjust_lr=epoch < args.epochs)

        # evaluate on validation set
        torch.cuda.empty_cache()
        mIoU = validate(evaluator, model, epoch)
        logger.writer.add_scalar("mIoU", mIoU, epoch + 1)
        logger.log("mIoU: %f" % mIoU)

        # remember best mIoU and save checkpoint
        is_best = mIoU > best_mIoU
        best_mIoU = max(mIoU, best_mIoU)
        save_checkpoint(
            args.save_dir, {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mIoU': best_mIoU,
            }, is_best)

    # Use the run logger like the rest of this function; root logging may not
    # be configured, and at its default WARNING level the message would be lost.
    logger.log('Best accuracy: {mIoU:.3f}'.format(mIoU=best_mIoU))
def load_dataset(opt):
    """Build the ``(train_data, test_data)`` pair selected by ``opt.dataset``.

    Supported names: 'KITTI_64', 'KITTI_256', 'Cityscapes_128x256', 'Pose_64',
    'Pose_128'.

    Training mode (``opt.inference`` falsy):
        The 'train' split is built with horizontal flips and
        ``n_past + n_future`` frames. An evaluation set is built without flips
        and with ``n_past + n_eval`` frames:
        - KITTI: ``KITTI``, split 'val';
        - Cityscapes: ``CityscapesTest``, split 'val';
        - Pose: ``PoseTest``, split 'test'.

    Inference mode:
        ``train_data`` is a ``DummyDataset`` placeholder, and ``test_data`` is
        the dataset's test class on split 'test'.

    Raises:
        ValueError: if ``opt.dataset`` is not a supported name (previously this
            surfaced as an ``UnboundLocalError`` at the ``return``).
    """
    def make(dataset_cls, split, video_length, hflip):
        # All dataset classes here share the same constructor keywords.
        return dataset_cls(
            data_root=opt.data_root,
            split=split,
            frame_sampling_rate=opt.frame_sampling_rate,
            video_length=video_length,
            hflip=hflip,
        )

    # Imports stay lazy so only the selected dataset's module is loaded.
    if opt.dataset in ['KITTI_64', 'KITTI_256']:
        from data.kitti import KITTI as train_cls, KITTITest as test_cls
        eval_cls, eval_split = train_cls, 'val'
    elif opt.dataset in ['Cityscapes_128x256']:
        from data.cityscapes import Cityscapes as train_cls, CityscapesTest as test_cls
        eval_cls, eval_split = test_cls, 'val'
    elif opt.dataset in ['Pose_64', 'Pose_128']:
        from data.pose import Pose as train_cls, PoseTest as test_cls
        eval_cls, eval_split = test_cls, 'test'
    else:
        raise ValueError("Unsupported dataset: %r" % (opt.dataset,))

    if opt.inference:
        train_data = DummyDataset()
        test_data = make(test_cls, 'test', opt.n_past + opt.n_eval, hflip=False)
    else:
        train_data = make(train_cls, 'train', opt.n_past + opt.n_future, hflip=True)
        test_data = make(eval_cls, eval_split, opt.n_past + opt.n_eval, hflip=False)

    return train_data, test_data