Example #1
def test(cfg, writer, logger):
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))
    ## create dataset
    default_gpu = cfg['model']['default_gpu']
    device = torch.device(
        "cuda:{}".format(default_gpu) if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(
        cfg, writer, logger
    )  #source_train\ target_train\ source_valid\ target_valid + _loader

    model = CustomModel(cfg, writer, logger)
    running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    val_loss_meter = averageMeter()
    source_val_loss_meter = averageMeter()
    time_meter = averageMeter()
    loss_fn = get_loss_function(cfg)
    path = cfg['test']['path']
    checkpoint = torch.load(path)
    model.adaptive_load_nets(model.BaseNet,
                             checkpoint['DeepLab']['model_state'])

    validation(
                model, logger, writer, datasets, device, running_metrics_val, val_loss_meter, loss_fn,\
                source_val_loss_meter, source_running_metrics_val, iters = model.iter
                )
Example #2
def train(model, args):
    # model compile
    lr = 0.0005
    optimizer = Adam(lr=lr)
    model.compile(optimizer=optimizer, loss=get_loss_function())

    # data generator
    if args.frame_length == 5:
        input_generator = input_generator_fl5
    elif args.frame_length == 4:
        input_generator = input_generator_fl4
    elif args.frame_length == 3:
        input_generator = input_generator_fl3
    else:
        raise Exception(
            f'args.frame_length must be an integer between 3 and 5, but received {args.frame_length}.'
        )
    train_datagen = input_generator(args.train_list)
    valid_datagen = input_generator(args.valid_list, val_mode=True)

    # callbacks
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H%M")
    output = pathlib.Path(
        f'../output/{now}_{args.model_name}_fl{args.frame_length}_{args.memo}')
    output.mkdir(exist_ok=True, parents=True)
    cp = ModelCheckpoint(filepath=f'{output}/weights.h5',
                         monitor='val_loss',
                         save_best_only=True,
                         save_weights_only=True,
                         verbose=0,
                         mode='auto')
    logger = CSVLogger(f'{output}/history.csv')

    ## learning rate
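    # step decay: keep the base lr for the first 10 epochs, then scale it by 0.1, and by 0.01 from epoch 15 on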
    def step_decay(epoch):
        factor = 1
        if epoch >= 10: factor = 0.1
        if epoch >= 15: factor = 0.01
        return lr * factor

    lr_schedule = LearningRateScheduler(step_decay, verbose=1)

    # START training
    batch_size = 64
    epochs = 20
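    # fit_generator is the legacy Keras API; recent tf.keras versions accept generators directly in model.fit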
    model.fit_generator(
        generator=train_datagen.flow_from_directory(batch_size),
        steps_per_epoch=len(train_datagen.data_paths) // batch_size,
        epochs=epochs,
        initial_epoch=0,
        verbose=1,
        callbacks=[cp, logger, lr_schedule],
        validation_data=valid_datagen.flow_from_directory(batch_size),
        validation_steps=len(valid_datagen.data_paths) // batch_size,
        max_queue_size=20)
Example #3
    def __init__(self, cfg, writer, logger, use_pseudo_label=False, modal_num=3, multimodal_merger=multimodal_merger):
        self.cfg = cfg
        self.writer = writer
        self.class_numbers = 19
        self.logger = logger
        cfg_model = cfg['model']
        self.cfg_model = cfg_model
        self.best_iou = -100
        self.iter = 0
        self.nets = []
        self.split_gpu = 0
        self.default_gpu = cfg['model']['default_gpu']
        self.PredNet_Dir = None
        self.valid_classes = cfg['training']['valid_classes']
        self.G_train = True
        self.cls_feature_weight = cfg['training']['cls_feature_weight']
        self.use_pseudo_label = use_pseudo_label
        self.modal_num = modal_num

        # cluster vectors & cuda initialization
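        # per-modality slots (modal_num + 1 in total): 19 anchor vectors of dim 256 per slot, their running counts,
        # 19x19 anchor distances, and per-class confidence thresholds initialized to 0.6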
        self.objective_vectors_group = torch.zeros(self.modal_num + 1, 19, 256).cuda()
        self.objective_vectors_num_group = torch.zeros(self.modal_num + 1, 19).cuda()
        self.objective_vectors_dis_group = torch.zeros(self.modal_num + 1, 19, 19).cuda()
        self.class_threshold_group = torch.full([self.modal_num + 1, 19], 0.6).cuda()

        self.disc_T = torch.FloatTensor([0.0]).cuda()

        #self.metrics = CustomMetrics(self.class_numbers)
        self.metrics = CustomMetrics(self.class_numbers, modal_num=self.modal_num, model=self)

        # multimodal / multi-branch merger
        self.multimodal_merger = multimodal_merger

        bn = cfg_model['bn']
        if bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        elif bn == 'gn':
            BatchNorm = nn.GroupNorm
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(bn))

        self.PredNet = DeepLab(
                num_classes=19,
                backbone=cfg_model['basenet']['version'],
                output_stride=16,
                bn=cfg_model['bn'],
                freeze_bn=True,
                modal_num=self.modal_num
                ).cuda()
        self.load_PredNet(cfg, writer, logger, dir=None, net=self.PredNet)
        self.PredNet_DP = self.init_device(self.PredNet, gpu_id=self.default_gpu, whether_DP=True)
        self.PredNet.eval()
        self.PredNet_num = 0

        self.PredDnet = FCDiscriminator(inplanes=19)
        self.load_PredDnet(cfg, writer, logger, dir=None, net=self.PredDnet)
        self.PredDnet_DP = self.init_device(self.PredDnet, gpu_id=self.default_gpu, whether_DP=True)
        self.PredDnet.eval()

        self.BaseNet = DeepLab(
                            num_classes=19,
                            backbone=cfg_model['basenet']['version'],
                            output_stride=16,
                            bn=cfg_model['bn'],
                            freeze_bn=True, 
                            modal_num=self.modal_num
                            )

        logger.info('the backbone is {}'.format(cfg_model['basenet']['version']))

        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets.extend([self.BaseNet])
        self.nets_DP = [self.BaseNet_DP]

        # Discriminator
        self.SOURCE_LABEL = 0
        self.TARGET_LABEL = 1
        self.DNets = []
        self.DNets_DP = []
        for _ in range(self.modal_num+1):
            _net_d = FCDiscriminator(inplanes=19)
            self.DNets.append(_net_d)
            _net_d_DP = self.init_device(_net_d, gpu_id=self.default_gpu, whether_DP=True)
            self.DNets_DP.append(_net_d_DP)

        self.nets.extend(self.DNets)
        self.nets_DP.extend(self.DNets_DP)

        self.optimizers = []
        self.schedulers = []        

        optimizer_cls = torch.optim.SGD
        optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() 
                            if k != 'name'}

        optimizer_cls_D = torch.optim.Adam
        optimizer_params_D = {k:v for k, v in cfg['training']['optimizer_D'].items() 
                            if k != 'name'}
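        # SGD for the segmentation network, Adam for the discriminators; every cfg entry except 'name' is passed through as an optimizer kwarg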

        self.BaseOpti = optimizer_cls(self.BaseNet.optim_parameters(cfg['training']['optimizer']['lr']), **optimizer_params)

        self.optimizers.extend([self.BaseOpti])

        self.DiscOptis = []
        for _d_net in self.DNets: 
            self.DiscOptis.append(
                optimizer_cls_D(_d_net.parameters(), **optimizer_params_D)
            )
        self.optimizers.extend(self.DiscOptis)

        self.schedulers = []        

        # BaseSchedule detail: see scheduler_step()
        self.learning_rate = cfg['training']['optimizer']['lr']
        self.gamma = cfg['training']['lr_schedule']['gamma']
        self.num_steps = cfg['training']['lr_schedule']['max_iter']
        self._BaseSchedule_nouse = get_scheduler(self.BaseOpti, cfg['training']['lr_schedule'])
        self.schedulers.extend([self._BaseSchedule_nouse])
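        # (scheduler_step() presumably decays the base lr using gamma and num_steps, e.g. a poly or step rule; the exact formula lives in that method)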

        self.DiscSchedules = []
        for _disc_opt in self.DiscOptis:
            self.DiscSchedules.append(
                get_scheduler(_disc_opt, cfg['training']['lr_schedule'])
            )
        self.schedulers.extend(self.DiscSchedules)
        self.setup(cfg, writer, logger)

        self.adv_source_label = 0
        self.adv_target_label = 1
        self.bceloss = nn.BCEWithLogitsLoss(reduction='none')
        self.loss_fn = get_loss_function(cfg)
        pseudo_cfg = copy.deepcopy(cfg)
        pseudo_cfg['training']['loss']['name'] = 'cross_entropy4d'
        self.pseudo_loss_fn = get_loss_function(pseudo_cfg)
        self.mseloss = nn.MSELoss()
        self.l1loss = nn.L1Loss()
        self.smoothloss = nn.SmoothL1Loss()
        self.triplet_loss = nn.TripletMarginLoss()
        self.kl_distance = nn.KLDivLoss(reduction='none')
Example #4
def train(cfg):
    
    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        #img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        #img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=True)

    valloader = data.DataLoader(v_loader, 
                                batch_size=cfg['training']['batch_size'], 
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() 
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
 
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            print("=====>",
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            print("=====>","No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True
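    # flag lets the inner loop over the dataloader stop the outer while-loop once train_iters is reached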

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()
            
            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'], 
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()


                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())


                print("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k,':',v)

                for k, v in class_iou.items():
                    print('{}: {}'.format(k, v))

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join('./checkpoint',
                                             "{}_{}_best_model.pkl".format(
                                                 cfg['model']['arch'],
                                                 cfg['data']['dataset']))
                    print("saving···")
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
Example #5
def train(cfg, logger):

    # Setup Seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup Device
    device = torch.device("cuda:{}".format(cfg["training"]["gpu_idx"])
                          if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        split=cfg["data"]["train_split"],
    )

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
    )

    n_classes = t_loader.n_classes
    n_val = len(v_loader.files['val'])

    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes, n_val)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)
    model = torch.nn.DataParallel(model,
                                  device_ids=[cfg["training"]["gpu_idx"]])

    # Setup Optimizer, lr_scheduler and Loss Function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    # Resume Trained Model
    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    # Start Training
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_dice = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, img_name) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            # print train loss
            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            # validation
            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                img_name_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred, i_val)
                        val_loss_meter.update(val_loss.item())

                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                # print val metrics
                score, class_dice = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))

                for k, v in class_dice.items():
                    logger.info("{}: {}".format(k, v))

                val_loss_meter.reset()
                running_metrics_val.reset()

                # save model
                if score["Dice : \t"] >= best_dice:
                    best_dice = score["Dice : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_dice": best_dice,
                    }
                    save_path = os.path.join(
                        cfg["training"]["model_dir"],
                        "{}_{}.pkl".format(cfg["model"]["arch"],
                                           cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
Example #6
def CAC(cfg, writer, logger):
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))
    ## create dataset
    default_gpu = cfg['model']['default_gpu']
    device = torch.device(
        "cuda:{}".format(default_gpu) if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(
        cfg, writer, logger
    )  #source_train\ target_train\ source_valid\ target_valid + _loader

    model = CustomModel(cfg, writer, logger)

    # Setup Metrics
    running_metrics_val = RunningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = RunningScore(cfg['data']['target']['n_class'])
    val_loss_meter = AverageMeter()
    source_val_loss_meter = AverageMeter()
    time_meter = AverageMeter()
    loss_fn = get_loss_function(cfg)
    flag_train = True

    epoches = cfg['training']['epoches']

    source_train_loader = datasets.source_train_loader
    target_train_loader = datasets.target_train_loader
    logger.info('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    print('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    logger.info('target train batchsize is {}'.format(
        target_train_loader.batch_size))
    print('target train batchsize is {}'.format(
        target_train_loader.batch_size))

    val_loader = None
    if cfg.get('valset') == 'gta5':
        val_loader = datasets.source_valid_loader
        logger.info('valset is gta5')
        print('valset is gta5')
    else:
        val_loader = datasets.target_valid_loader
        logger.info('valset is cityscapes')
        print('valset is cityscapes')
    logger.info('val batchsize is {}'.format(val_loader.batch_size))
    print('val batchsize is {}'.format(val_loader.batch_size))

    # load category anchors
    # objective_vectors = torch.load('category_anchors')
    # model.objective_vectors = objective_vectors['objective_vectors']
    # model.objective_vectors_num = objective_vectors['objective_num']
    class_features = Class_Features(numbers=19)

    # begin training
    model.iter = 0
    for epoch in range(epoches):
        if not flag_train:
            break
        if model.iter > cfg['training']['train_iters']:
            break

        # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA

        for (target_image, target_label,
             target_img_name) in datasets.target_train_loader:
            model.iter += 1
            i = model.iter
            if i > cfg['training']['train_iters']:
                break
            images, labels, source_img_name = datasets.source_train_loader.next(
            )
            start_ts = time.time()

            images = images.to(device)
            labels = labels.to(device)
            target_image = target_image.to(device)
            target_label = target_label.to(device)
            model.scheduler_step()
            model.train(logger=logger)
            if cfg['training'].get('freeze_bn') == True:
                model.freeze_bn_apply()
            model.optimizer_zero_grad()
            if model.PredNet.training:
                model.PredNet.eval()
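            # run the frozen PredNet on the source batch and accumulate per-class mean feature vectors (the category anchors)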
            with torch.no_grad():
                _, _, feat_cls, output = model.PredNet_Forward(images)
                batch, w, h = labels.size()
                newlabels = labels.reshape([batch, 1, w, h]).float()
                newlabels = F.interpolate(newlabels,
                                          size=feat_cls.size()[2:],
                                          mode='nearest')
                vectors, ids = class_features.calculate_mean_vector(
                    feat_cls, output, newlabels, model)
                for t in range(len(ids)):
                    model.update_objective_SingleVector(
                        ids[t], vectors[t].detach().cpu().numpy(), 'mean')

            time_meter.update(time.time() - start_ts)
            if model.iter % 20 == 0:
                print("Iter [{:d}] Time {:.4f}".format(model.iter,
                                                       time_meter.avg))

            if (i + 1) == cfg['training']['train_iters']:
                flag_train = False
                break
    save_path = os.path.join(
        writer.file_writer.get_logdir(), "anchors_on_{}_from_{}".format(
            cfg['data']['source']['name'],
            cfg['model']['arch'],
        ))
    torch.save(model.objective_vectors, save_path)
Example #7
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Dataloader
    trainloader = get_loader(cfg, "train")
    valloader = get_loader(cfg, "val")

    n_classes = cfg["data"]["n_classes"]
    n_channels = cfg["data"]["channels"]

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg, n_classes, n_channels).to(device)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.module.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    # fig = plt.figure()
    # plt.rcParams['xtick.major.pad'] = '15'
    # fig.show()
    # fig.canvas.draw()

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            # plot_grad_flow(model.named_parameters(), fig)

            # zero mean conv for layer 1 of dsm encoder
            optimizer.step()
            scheduler.step()
            # m = model._modules['module'].encoderDSM._modules['0']._modules['0']
            # model._modules['module'].encoderDSM._modules['0']._modules['0'].weight = m.weight - torch.mean(m.weight)
            model = zero_mean(model, all=False)

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        # plt.imshow(v_loader.decode_segmap(gt[0,:,:]))
                        # plt.imshow(v_loader.decode_segmap(pred[0, :, :]))
                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    #print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
Example #8
def train(cfg, writer, logger):
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))
    ## create dataset
    default_gpu = cfg['model']['default_gpu']
    device = torch.device(
        "cuda:{}".format(default_gpu) if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(cfg, writer, logger)

    use_pseudo_label = False
    model = CustomModel(cfg, writer, logger, use_pseudo_label, modal_num=3)

    # Setup Metrics
    running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    val_loss_meter = averageMeter()
    source_val_loss_meter = averageMeter()
    time_meter = averageMeter()
    loss_fn = get_loss_function(cfg)
    flag_train = True

    epoches = cfg['training']['epoches']

    source_train_loader = datasets.source_train_loader
    target_train_loader = datasets.target_train_loader

    logger.info('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    print('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    logger.info('target train batchsize is {}'.format(
        target_train_loader.batch_size))
    print('target train batchsize is {}'.format(
        target_train_loader.batch_size))

    val_loader = None
    if cfg.get('valset') == 'gta5':
        val_loader = datasets.source_valid_loader
        logger.info('valset is gta5')
        print('valset is gta5')
    else:
        val_loader = datasets.target_valid_loader
        logger.info('valset is cityscapes')
        print('valset is cityscapes')
    logger.info('val batchsize is {}'.format(val_loader.batch_size))
    print('val batchsize is {}'.format(val_loader.batch_size))

    # load category anchors
    """
    objective_vectors = torch.load('category_anchors')
    model.objective_vectors = objective_vectors['objective_vectors']
    model.objective_vectors_num = objective_vectors['objective_num']
    """

    # begin training
    model.iter = 0
    for epoch in range(epoches):
        if not flag_train:
            break
        if model.iter > cfg['training']['train_iters']:
            break

        if use_pseudo_label:
            # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA
            score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores(
            )
            print('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))

            logger.info('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))
            logger.info('clus_Recall: {}'.format(
                model.metrics.calc_mean_Clu_recall()))
            logger.info(model.metrics.classes_recall_clu[:, 0] /
                        model.metrics.classes_recall_clu[:, 1])
            logger.info('clus_Acc: {}'.format(
                np.mean(model.metrics.classes_recall_clu[:, 0] /
                        model.metrics.classes_recall_clu[:, 1])))
            logger.info(model.metrics.classes_recall_clu[:, 0] /
                        model.metrics.classes_recall_clu[:, 2])

            score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores(
            )
            logger.info('thr_IoU: {}'.format(score_cl["Mean IoU : \t"]))
            logger.info('thr_Recall: {}'.format(
                model.metrics.calc_mean_Thr_recall()))
            logger.info(model.metrics.classes_recall_thr[:, 0] /
                        model.metrics.classes_recall_thr[:, 1])
            logger.info('thr_Acc: {}'.format(
                np.mean(model.metrics.classes_recall_thr[:, 0] /
                        model.metrics.classes_recall_thr[:, 1])))
            logger.info(model.metrics.classes_recall_thr[:, 0] /
                        model.metrics.classes_recall_thr[:, 2])
        model.metrics.reset()

        for (target_image, target_label,
             target_img_name) in datasets.target_train_loader:
            model.iter += 1
            i = model.iter
            if i > cfg['training']['train_iters']:
                break
            source_batchsize = cfg['data']['source']['batch_size']
            # load source data
            images, labels, source_img_name = datasets.source_train_loader.next(
            )
            start_ts = time.time()
            images = images.to(device)
            labels = labels.to(device)
            # load target data
            target_image = target_image.to(device)
            target_label = target_label.to(device)
            #model.scheduler_step()
            model.train(logger=logger)
            if cfg['training'].get('freeze_bn') == True:
                model.freeze_bn_apply()
            model.optimizer_zerograd()
            # switch on modalities: map each image to a modality id (0: clear, 1: foggy, 2: rain) from its file name
            source_modal_ids = []
            for _img_name in source_img_name:
                if 'gtav2cityscapes' in _img_name:
                    source_modal_ids.append(0)
                elif 'gtav2cityfoggy' in _img_name:
                    source_modal_ids.append(1)
                elif 'gtav2cityrain' in _img_name:
                    source_modal_ids.append(2)
                else:
                    assert False, "[ERROR] unknown source image: expected gtav2cityscapes, gtav2cityfoggy or gtav2cityrain in the file name"

            target_modal_ids = []
            for _img_name in target_img_name:
                if 'Cityscapes_foggy' in _img_name:
                    target_modal_ids.append(1)
                elif 'Cityscapes_rain' in _img_name:
                    target_modal_ids.append(2)
                else:
                    target_modal_ids.append(0)

            loss, loss_cls_L2, loss_pseudo = model.step(
                images, labels, source_modal_ids, target_image, target_label,
                target_modal_ids, use_pseudo_label)
            # scheduler step
            model.scheduler_step()
            if loss_cls_L2 > 10:
                logger.info('loss_cls_l2 abnormal!!')

            time_meter.update(time.time() - start_ts)
            if (i + 1) % cfg['training']['print_interval'] == 0:
                unchanged_cls_num = 0
                if use_pseudo_label:
                    fmt_str = "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss: {:.4f}  Loss_L2: {:.4f}  Loss_pseudo: {:.4f}  Time/Image: {:.4f}"
                else:
                    fmt_str = "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss_GTA: {:.4f}  Loss_adv: {:.4f}  Loss_D: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    epoch + 1, epoches, i + 1, cfg['training']['train_iters'],
                    loss.item(), loss_cls_L2.item(), loss_pseudo.item(),
                    time_meter.avg / cfg['data']['source']['batch_size'])

                print(print_str)
                logger.info(print_str)
                logger.info(
                    'unchanged number of objective class vector: {}'.format(
                        unchanged_cls_num))
                if use_pseudo_label:
                    loss_names = [
                        'train_loss', 'train_L2Loss', 'train_pseudoLoss'
                    ]
                else:
                    loss_names = [
                        'train_loss_GTA', 'train_loss_adv', 'train_loss_D'
                    ]
                writer.add_scalar('loss/{}'.format(loss_names[0]), loss.item(),
                                  i + 1)
                writer.add_scalar('loss/{}'.format(loss_names[1]),
                                  loss_cls_L2.item(), i + 1)
                writer.add_scalar('loss/{}'.format(loss_names[2]),
                                  loss_pseudo.item(), i + 1)
                time_meter.reset()

                if use_pseudo_label:
                    score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores(
                    )
                    logger.info('clus_IoU: {}'.format(
                        score_cl["Mean IoU : \t"]))
                    logger.info('clus_Recall: {}'.format(
                        model.metrics.calc_mean_Clu_recall()))
                    logger.info('clus_Acc: {}'.format(
                        np.mean(model.metrics.classes_recall_clu[:, 0] /
                                model.metrics.classes_recall_clu[:, 2])))

                    score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores(
                    )
                    logger.info('thr_IoU: {}'.format(
                        score_cl["Mean IoU : \t"]))
                    logger.info('thr_Recall: {}'.format(
                        model.metrics.calc_mean_Thr_recall()))
                    logger.info('thr_Acc: {}'.format(
                        np.mean(model.metrics.classes_recall_thr[:, 0] /
                                model.metrics.classes_recall_thr[:, 2])))

            # evaluation
            if (i + 1) % cfg['training']['val_interval'] == 0 or \
                (i + 1) == cfg['training']['train_iters']:
                validation(
                    model, logger, writer, datasets, device, running_metrics_val, val_loss_meter, loss_fn,\
                    source_val_loss_meter, source_running_metrics_val, iters = model.iter
                    )
                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))
            if (i + 1) == cfg['training']['train_iters']:
                flag_train = False
                break
Example #9
    def __init__(self, cfg, writer, logger):
        # super(CustomModel, self).__init__()
        self.cfg = cfg
        self.writer = writer
        self.class_numbers = 19
        self.logger = logger
        cfg_model = cfg['model']
        self.cfg_model = cfg_model
        self.best_iou = -100
        self.iter = 0
        self.nets = []
        self.split_gpu = 0
        self.default_gpu = cfg['model']['default_gpu']
        self.PredNet_Dir = None
        self.valid_classes = cfg['training']['valid_classes']
        self.G_train = True
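        # per-class anchor statistics: 19 anchor vectors of dim 256, their running counts, a 19x19 distance table, and a fixed 0.95 confidence threshold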
        self.objective_vectors = np.zeros([19, 256])
        self.objective_vectors_num = np.zeros([19])
        self.objective_vectors_dis = np.zeros([19, 19])
        self.class_threshold = np.full([19], 0.95)
        self.metrics = CustomMetrics(self.class_numbers)
        self.cls_feature_weight = cfg['training']['cls_feature_weight']

        bn = cfg_model['bn']
        if bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        # elif bn == 'sync_abn':
        #     BatchNorm = InPlaceABNSync
        elif bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        # elif bn == 'abn':
        #     BatchNorm = InPlaceABN
        elif bn == 'gn':
            BatchNorm = nn.GroupNorm
        else:
            raise NotImplementedError(
                'batch norm choice {} is not implemented'.format(bn))
        self.PredNet = DeepLab(
            num_classes=19,
            backbone=cfg_model['basenet']['version'],
            output_stride=16,
            bn=cfg_model['bn'],
            freeze_bn=True,
        ).cuda()
        self.load_PredNet(cfg, writer, logger, dir=None, net=self.PredNet)
        self.PredNet_DP = self.init_device(self.PredNet,
                                           gpu_id=self.default_gpu,
                                           whether_DP=True)
        self.PredNet.eval()
        self.PredNet_num = 0

        self.BaseNet = DeepLab(
            num_classes=19,
            backbone=cfg_model['basenet']['version'],
            output_stride=16,
            bn=cfg_model['bn'],
            freeze_bn=False,
        )

        logger.info('the backbone is {}'.format(
            cfg_model['basenet']['version']))

        self.BaseNet_DP = self.init_device(self.BaseNet,
                                           gpu_id=self.default_gpu,
                                           whether_DP=True)
        self.nets.extend([self.BaseNet])
        self.nets_DP = [self.BaseNet_DP]

        self.optimizers = []
        self.schedulers = []
        # optimizer_cls = get_optimizer(cfg)
        optimizer_cls = torch.optim.SGD
        optimizer_params = {
            k: v
            for k, v in cfg['training']['optimizer'].items() if k != 'name'
        }
        # optimizer_cls_D = torch.optim.SGD
        # optimizer_params_D = {k:v for k, v in cfg['training']['optimizer_D'].items()
        #                     if k != 'name'}
        self.BaseOpti = optimizer_cls(self.BaseNet.parameters(),
                                      **optimizer_params)
        self.optimizers.extend([self.BaseOpti])

        self.BaseSchedule = get_scheduler(self.BaseOpti,
                                          cfg['training']['lr_schedule'])
        self.schedulers.extend([self.BaseSchedule])
        self.setup(cfg, writer, logger)

        self.adv_source_label = 0
        self.adv_target_label = 1
        self.bceloss = nn.BCEWithLogitsLoss(reduction='mean')
        self.loss_fn = get_loss_function(cfg)
        self.mseloss = nn.MSELoss()
        self.l1loss = nn.L1Loss()
        self.smoothloss = nn.SmoothL1Loss()
        self.triplet_loss = nn.TripletMarginLoss()
Example #10
"""
import sys
sys.dont_write_bytecode = True

from config import base_config

from optimizer import get_optimizer
from loss import get_loss_function
from model import get_model
from metrics import get_metrics_lst
from callback import get_callbacks
from trainer import Trainer
from data import get_datasets

if __name__ == "__main__":

    config = base_config()
    config.METRICS_LST = get_metrics_lst()
    config.OPTIMIZER = get_optimizer()
    config.LOSS_FUNC = get_loss_function()
    config.CALLBACK_LST = get_callbacks(config)

    config.display()

    model = get_model(config)
    datasets = get_datasets(config)
    trainer = Trainer(datasets, model, config)

    trainer._compile()
    trainer.train()
Example #11
def train(cfg, writer, logger):
    init_random()

    device = torch.device("cuda:{}".format(cfg['model']['default_gpu'])
                          if torch.cuda.is_available() else 'cpu')

    # create dataSet
    data_sets = create_dataset(cfg, writer, logger)  # source_train\ target_train\ source_valid\ target_valid + _loader
    if cfg.get('valset') == 'gta5':
        val_loader = data_sets.source_valid_loader
    else:
        val_loader = data_sets.target_valid_loader
    logger.info('source train batchsize is {}'.format(data_sets.source_train_loader.args.get('batch_size')))
    print('source train batchsize is {}'.format(data_sets.source_train_loader.args.get('batch_size')))
    logger.info('target train batchsize is {}'.format(data_sets.target_train_loader.batch_size))
    print('target train batchsize is {}'.format(data_sets.target_train_loader.batch_size))
    logger.info('valset is {}'.format(cfg.get('valset')))
    print('val_set is {}'.format(cfg.get('valset')))
    logger.info('val batch_size is {}'.format(val_loader.batch_size))
    print('val batch_size is {}'.format(val_loader.batch_size))

    # create model
    model = CustomModel(cfg, writer, logger)

    # LOSS function
    loss_fn = get_loss_function(cfg)

    # load category anchors
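    # 'category_anchors' is expected to hold per-class mean feature vectors ('objective_vectors') and their sample
    # counts ('objective_num'), presumably precomputed by an anchor-extraction pass like the CAC example above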
    objective_vectors = torch.load('category_anchors')
    model.objective_vectors = objective_vectors['objective_vectors']
    model.objective_vectors_num = objective_vectors['objective_num']

    # Setup Metrics
    running_metrics_val = RunningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = RunningScore(cfg['data']['source']['n_class'])
    val_loss_meter, source_val_loss_meter = AverageMeter(), AverageMeter()
    time_meter = AverageMeter()

    # begin training
    model.iter = 0
    epochs = cfg['training']['epochs']
    for epoch in tqdm(range(epochs)):
        if model.iter > cfg['training']['train_iters']:
            break

        for (target_image, target_label, target_img_name) in tqdm(data_sets.target_train_loader):
            start_ts = time.time()
            model.iter += 1
            if model.iter > cfg['training']['train_iters']:
                break

            ############################
            # train on source & target #
            ############################
            # get data
            images, labels, source_img_name = data_sets.source_train_loader.next()
            images, labels = images.to(device), labels.to(device)
            target_image, target_label = target_image.to(device), target_label.to(device)

            # init model
            model.train(logger=logger)
            if cfg['training'].get('freeze_bn'):
                model.freeze_bn_apply()
            model.optimizer_zero_grad()

            # train for one batch
            loss, loss_cls_L2, loss_pseudo = model.step(images, labels, target_image, target_label)
            model.scheduler_step()

            if loss_cls_L2 > 10:
                logger.info('loss_cls_l2 abnormal!!')

            # print
            time_meter.update(time.time() - start_ts)
            if (model.iter + 1) % cfg['training']['print_interval'] == 0:
                unchanged_cls_num = 0
                fmt_str = "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss: {:.4f} " \
                          "Loss_cls_L2: {:.4f}  Loss_pseudo: {:.4f}  Time/Image: {:.4f} "
                print_str = fmt_str.format(epoch + 1, epochs, model.iter + 1, cfg['training']['train_iters'],
                                           loss.item(), loss_cls_L2, loss_pseudo,
                                           time_meter.avg / cfg['data']['source']['batch_size'])

                print(print_str)
                logger.info(print_str)
                logger.info('unchanged number of objective class vector: {}'.format(unchanged_cls_num))
                writer.add_scalar('loss/train_loss', loss.item(), model.iter + 1)
                writer.add_scalar('loss/train_cls_L2Loss', loss_cls_L2, model.iter + 1)
                writer.add_scalar('loss/train_pseudoLoss', loss_pseudo, model.iter + 1)
                time_meter.reset()

                score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores()
                logger.info('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))
                logger.info('clus_Recall: {}'.format(model.metrics.calc_mean_Clu_recall()))
                logger.info('clus_Acc: {}'.format(
                    np.mean(model.metrics.classes_recall_clu[:, 0] / model.metrics.classes_recall_clu[:, 2])))

                score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores()
                logger.info('thr_IoU: {}'.format(score_cl["Mean IoU : \t"]))
                logger.info('thr_Recall: {}'.format(model.metrics.calc_mean_Thr_recall()))
                logger.info('thr_Acc: {}'.format(
                    np.mean(model.metrics.classes_recall_thr[:, 0] / model.metrics.classes_recall_thr[:, 2])))

            # evaluation
            if (model.iter + 1) % cfg['training']['val_interval'] == 0 or \
                    (model.iter + 1) == cfg['training']['train_iters']:
                validation(model, logger, writer, data_sets, device, running_metrics_val, val_loss_meter, loss_fn,
                           source_val_loss_meter, source_running_metrics_val, iters=model.iter)

                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))

            # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA
            monitor(model)

            model.metrics.reset()
Example #12
def train(cfg, writer, logger, logdir):
    cudnn.benchmark = True

    # Setup seeds
    if cfg['training']['seed'] is not None:
        logger.info("Using seed {}".format(seed))
        torch.manual_seed(cfg.get("seed", cfg['training']['seed']))
        torch.cuda.manual_seed(cfg.get("seed", cfg['training']['seed']))
        np.random.seed(cfg.get("seed", cfg['training']['seed']))
        random.seed(cfg.get("seed", cfg['training']['seed']))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["name"])

    tloader_params = {k: v for k, v in cfg["data"]["train"].items()}
    tloader_params.update({'root': cfg["data"]["root"]})

    vloader_params = {k: v for k, v in cfg["data"]["val"].items()}
    vloader_params.update({'root': cfg["data"]["root"]})

    t_loader = data_loader(**tloader_params)
    v_loader = data_loader(**vloader_params)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Model
    model = get_model(cfg["model"]["arch"], num_classes=n_classes).to(device)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    loss_type = cfg["training"]["loss"]["name"]
    if cfg["training"]["loss"][loss_type] is not None:
        loss_params = {
            k: v
            for k, v in cfg["training"]["loss"][loss_type].items()
        }
    else:
        loss_params = {}
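    # optional class-balanced reweighting: per-class weights from the training loader are handed to the loss as 'weight'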
    if cfg['training']['reweight']:
        per_cls_weights = t_loader.get_balanced_weight()
        per_cls_weights = torch.FloatTensor(per_cls_weights).to(
            device) if per_cls_weights is not None else None
        loss_params["weight"] = per_cls_weights
    loss_fn = get_loss_function(cfg, **loss_params)
    logger.info("Using loss {}".format(loss_fn))

    start_epoch = 0
    best_acc1 = -100.0
    # Resume pre-trained model
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            best_acc1 = checkpoint["best_acc1"]
            start_epoch = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    ##================== Training ============================
    for epoch in range(start_epoch, cfg['training']['train_epoch']):
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')

        end = time.time()
        model.train()
        adjust_learning_rate(optimizer, epoch, cfg)
        for i, (input, target) in enumerate(trainloader):
            data_time.update(time.time() - end)
            input = input.cuda()
            target = target.cuda()

            logit = model(input)
            loss = loss_fn(logit, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure accuracy and record loss
            acc1, acc5 = accuracy(logit, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1[0], input.size(0))
            top5.update(acc5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % cfg["training"]["print_interval"] == 0:
                output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              epoch,
                              i,
                              len(trainloader),
                              batch_time=batch_time,
                              data_time=data_time,
                              loss=losses,
                              top1=top1,
                              top5=top5,
                              lr=optimizer.param_groups[-1]['lr']))  # TODO
                print(output)
                logger.info(output + '\n')

        writer.add_scalar('train/loss', losses.avg, epoch)
        writer.add_scalar('train/acc_top1', top1.avg, epoch)
        writer.add_scalar('train/acc_top5', top5.avg, epoch)
        writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)

        ##================== Evaluation ============================
        eval_top1 = AverageMeter('Acc@1', ':6.2f')
        eval_top5 = AverageMeter('Acc@5', ':6.2f')
        model.eval()
        with torch.no_grad():
            for i, (input, target) in enumerate(valloader):

                input = input.to(device)
                target = target.to(device)
                logit = model(input)
                acc1, acc5 = accuracy(logit, target, topk=(1, 5))
                eval_top1.update(acc1[0], input.size(0))
                eval_top5.update(acc5[0], input.size(0))

                if i % cfg["training"]["print_interval"] == 0:
                    output = ('Test: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                              'Prec@1 {top1.avg:.3f}\t'
                              'Prec@5 {top5.avg:.3f}'.format(
                                  epoch,
                                  i,
                                  len(valloader),
                                  top1=eval_top1,
                                  top5=eval_top5,
                                  lr=optimizer.param_groups[-1]['lr']))  # TODO
                    print(output)

        output = (
            'validation Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.
            format(top1=eval_top1, top5=eval_top5))
        logger.info(output + '\n')
        print(output)
        writer.add_scalar('val/acc_top1', eval_top1.avg, epoch)
        writer.add_scalar('val/acc_top5', eval_top5.avg, epoch)

        is_best = eval_top1.avg > best_acc1
        best_acc1 = max(eval_top1.avg, best_acc1)

        output_best = 'Best Prec@1: %.3f' % (best_acc1)
        logger.info(output_best + '\n')
        print(output_best)
        writer.add_scalar('val/acc_best_top1', best_acc1, epoch)

        save_checkpoint(
            logdir, {
                'epoch': epoch,
                'arch': cfg["model"]["arch"],
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
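
# A minimal sketch of the cfg-driven get_loss_function that the examples above call.
# This is an assumption about the common pattern (look up cfg['training']['loss']['name']
# in a registry and bind the remaining keys as keyword arguments), not the original
# implementation; the registry entries below are hypothetical.
import functools

import torch.nn.functional as F


def get_loss_function(cfg=None, **extra_params):
    # hypothetical name-to-callable registry; real projects register their own losses here
    key2loss = {
        'cross_entropy': F.cross_entropy,
        'cross_entropy4d': F.cross_entropy,
    }
    if cfg is None or cfg.get('training', {}).get('loss') is None:
        # fall back to plain cross entropy when no loss is configured
        return F.cross_entropy
    loss_cfg = cfg['training']['loss']
    loss_params = {k: v for k, v in loss_cfg.items() if k != 'name'}
    loss_params.update(extra_params)
    return functools.partial(key2loss[loss_cfg['name']], **loss_params)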