Example #1
    def setup_clients(self):
        self.logger.info("############setup_clients (START)#############")
        args_datadir = "./data/cifar10"
        for client_idx in range(self.args.client_number):
            self.logger.info("######client idx = " + str(client_idx))

            dataidxs = self.net_dataidx_map[client_idx]
            local_sample_number = len(dataidxs)

            split = int(np.floor(0.5 * local_sample_number))  # split index
            train_idxs = dataidxs[0:split]
            test_idxs = dataidxs[split:local_sample_number]

            train_local, _ = get_dataloader(self.args.dataset, args_datadir, self.args.batch_size, self.args.batch_size,
                                            train_idxs)
            self.logger.info("client_idx = %d, batch_num_train_local = %d" % (client_idx, len(train_local)))

            test_local, _ = get_dataloader(self.args.dataset, args_datadir, self.args.batch_size, self.args.batch_size,
                                           test_idxs)
            self.logger.info("client_idx = %d, batch_num_test_local = %d" % (client_idx, len(test_local)))

            self.logger.info('n_sample: %d' % local_sample_number)
            self.logger.info('n_training: %d' % len(train_local))
            self.logger.info('n_test: %d' % len(test_local))

            c = Client(client_idx, train_local, test_local, local_sample_number, self.args, self.logger,
                       self.device,
                       self.is_wandb_used)
            self.client_list.append(c)

        self.logger.info("############setup_clients (END)#############")
Example #2
def main():
    import os
    import sys
    import pathlib

    import anyconfig

    __dir__ = pathlib.Path(os.path.abspath(__file__))
    sys.path.append(str(__dir__))
    sys.path.append(str(__dir__.parent.parent))

    from models import build_model, build_loss
    from data_loader import get_dataloader
    from utils import Trainer
    from utils import get_post_processing
    from utils import get_metric

    config = anyconfig.load(open('config.yaml', 'rb'))
    train_loader = get_dataloader(config['dataset']['train'])
    validate_loader = get_dataloader(config['dataset']['validate'])
    criterion = build_loss(config['loss']).cuda()
    model = build_model(config['arch'])
    post_p = get_post_processing(config['post_processing'])
    metric = get_metric(config['metric'])

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      post_process=post_p,
                      metric_cls=metric,
                      validate_loader=validate_loader)
    trainer.train()
Example #3
def main(config):
    import os

    from mxnet import nd
    from mxnet.gluon.loss import CTCLoss

    from models import get_model
    from data_loader import get_dataloader
    from trainer import Trainer
    from utils import get_ctx, load

    if os.path.isfile(config['dataset']['alphabet']):
        config['dataset']['alphabet'] = ''.join(
            load(config['dataset']['alphabet']))

    prediction_type = config['arch']['args']['prediction']['type']
    num_class = len(config['dataset']['alphabet'])

    # loss setup
    if prediction_type == 'CTC':
        criterion = CTCLoss()
    else:
        raise NotImplementedError

    ctx = get_ctx(config['trainer']['gpus'])
    model = get_model(num_class, ctx, config['arch']['args'])
    model.hybridize()
    model.initialize(ctx=ctx)

    img_h, img_w = 32, 100
    for process in config['dataset']['train']['dataset']['args'][
            'pre_processes']:
        if process['type'] == "Resize":
            img_h = process['args']['img_h']
            img_w = process['args']['img_w']
            break
    img_channel = 3 if config['dataset']['train']['dataset']['args'][
        'img_mode'] != 'GRAY' else 1
    sample_input = nd.zeros((2, img_channel, img_h, img_w), ctx[0])
    num_label = model.get_batch_max_length(sample_input)

    train_loader = get_dataloader(config['dataset']['train'], num_label,
                                  config['dataset']['alphabet'])
    assert train_loader is not None
    if 'validate' in config['dataset']:
        validate_loader = get_dataloader(config['dataset']['validate'],
                                         num_label,
                                         config['dataset']['alphabet'])
    else:
        validate_loader = None

    config['lr_scheduler']['args']['step'] *= len(train_loader)

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      validate_loader=validate_loader,
                      sample_input=sample_input,
                      ctx=ctx)
    trainer.train()
Example #4
def main(config):
    import os
    import torch

    from modeling import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from utils import CTCLabelConverter, AttnLabelConverter, load
    if os.path.isfile(config['dataset']['alphabet']):
        config['dataset']['alphabet'] = ''.join(
            load(config['dataset']['alphabet']))

    prediction_type = config['arch']['head']['type']

    # loss setup
    criterion = build_loss(config['loss'])
    if prediction_type == 'CTC':
        converter = CTCLabelConverter(config['dataset']['alphabet'])
    elif prediction_type == 'Attn':
        converter = AttnLabelConverter(config['dataset']['alphabet'])
    else:
        raise NotImplementedError
    img_channel = 3 if config['dataset']['train']['dataset']['args'][
        'img_mode'] != 'GRAY' else 1
    config['arch']['backbone']['in_channels'] = img_channel
    config['arch']['head']['n_class'] = len(converter.character)
    # model = get_model(img_channel, len(converter.character), config['arch']['args'])
    model = build_model(config['arch'])
    img_h, img_w = 32, 100
    for process in config['dataset']['train']['dataset']['args'][
            'pre_processes']:
        if process['type'] == "Resize":
            img_h = process['args']['img_h']
            img_w = process['args']['img_w']
            break
    sample_input = torch.zeros((2, img_channel, img_h, img_w))
    num_label = model.get_batch_max_length(sample_input)
    train_loader = get_dataloader(config['dataset']['train'], num_label)
    assert train_loader is not None
    if 'validate' in config['dataset'] and config['dataset']['validate'][
            'dataset']['args']['data_path'][0] is not None:
        validate_loader = get_dataloader(config['dataset']['validate'],
                                         num_label)
    else:
        validate_loader = None

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      validate_loader=validate_loader,
                      sample_input=sample_input,
                      converter=converter)
    trainer.train()
Example #5
def main(config):
    import torch
    from model import get_model, get_loss, get_converter, get_post_processing
    from metric import get_metric
    from data_loader import get_dataloader
    from tools.rec_trainer import RecTrainer as rec
    from tools.det_trainer import DetTrainer as det
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=torch.cuda.device_count(),
            rank=args.local_rank)
        config['distributed'] = True
    else:
        config['distributed'] = False
    config['local_rank'] = args.local_rank
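    # Note (added for clarity, not from the original source): args.local_rank is supplied by
    # PyTorch's distributed launcher, which also sets the env:// rendezvous variables, e.g.
    #   python -m torch.distributed.launch --nproc_per_node=<num_gpus> <this_script>.py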
    train_loader = get_dataloader(config['dataset']['train'],
                                  config['distributed'])
    assert train_loader is not None
    if 'validate' in config['dataset']:
        validate_loader = get_dataloader(config['dataset']['validate'], False)
    else:
        validate_loader = None

    criterion = get_loss(config['loss']).cuda()

    if config.get('post_processing', None):
        post_p = get_post_processing(config['post_processing'])
    else:
        post_p = None

    metric = get_metric(config['metric'])

    if config['arch']['algorithm'] == 'rec':
        converter = get_converter(config['converter'])
        config['arch']['num_class'] = len(converter.character)
        model = get_model(config['arch'])
    else:
        converter = None
        model = get_model(config['arch'])

    # 'rec' and 'det' are the aliases imported above for RecTrainer / DetTrainer,
    # so eval() resolves config['arch']['algorithm'] to the matching trainer class
    trainer = eval(config['arch']['algorithm'])(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        post_process=post_p,
        metric=metric,
        validate_loader=validate_loader,
        converter=converter)
    trainer.train()
Example #6
def get_all_codes(cfg, output_path):

    print(output_path)
    if os.path.exists(output_path):
        return np.load(output_path, allow_pickle=True)['data'].item()
    ensure_dirs(os.path.dirname(output_path))

    print("start over")
    # Dataloader
    train_loader = get_dataloader(cfg, 'train', shuffle=False)
    test_loader = get_dataloader(cfg, 'test', shuffle=False)

    # Trainer
    trainer = Trainer(cfg)
    trainer.to(cfg.device)
    trainer.resume()

    with torch.no_grad():
        vis_dicts = {}
        for phase, loader in [['train', train_loader],
                              ['test', test_loader]]:

            vis_dict = None
            for t, data in enumerate(loader):
                vis_codes = trainer.get_latent_codes(data)
                if vis_dict is None:
                    vis_dict = {}
                    for key, value in vis_codes.items():
                        vis_dict[key] = [value]
                else:
                    for key, value in vis_codes.items():
                        vis_dict[key].append(value)
            for key, value in vis_dict.items():
                if phase == "test" and key == "content_code":
                    continue
                if key == "meta":
                    secondary_keys = value[0].keys()
                    num = len(value)
                    vis_dict[key] = {
                        secondary_key: [to_float(item) for i in range(num) for item in value[i][secondary_key]]
                        for secondary_key in secondary_keys}
                else:
                    vis_dict[key] = torch.cat(vis_dict[key], 0)
                    vis_dict[key] = vis_dict[key].cpu().numpy()
                    vis_dict[key] = to_float(vis_dict[key].reshape(vis_dict[key].shape[0], -1))
            vis_dicts[phase] = vis_dict

        np.savez_compressed(output_path, data=vis_dicts)
        return vis_dicts
Example #7
    def __init__(self, model_path, gpu_id=0):
        from model import get_model, get_loss, get_converter
        from data_loader import get_dataloader
        from metric import get_metric
        self.gpu_id = gpu_id

        if self.gpu_id is not None and isinstance(
                self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)
        checkpoint = torch.load(model_path, map_location=self.device)

        config = checkpoint['config']
        self.config = config
        self.model = get_model(config['arch'])
        # config['converter']['args']['character'] = 'license_plate'
        self.converter = get_converter(config['converter'])
        # self.post_process = get_post_processing(config['post_processing'])
        self.img_mode = config['dataset']['train']['dataset']['args'][
            'img_mode']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)
        self.model.eval()
        self.metric = get_metric(config['metric'])
        # config['dataset']['validate']['loader']['num_workers'] = 8
        # config['dataset']['validate']['dataset']['args']['pre_processes'] = [{'type': 'CropWordBox', 'args': [1, 1.2]}]
        if args.img_path is not None:
            config['dataset']['validate']['dataset']['args']['data_path'] = [
                args.img_path
            ]
        self.validate_loader = get_dataloader(config['dataset']['validate'],
                                              config['distributed'])
Example #8
def main():

    model = FaceNetModel(embedding_size=args.embedding_size,
                         num_classes=args.num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)

    if args.start_epoch != 0:
        checkpoint = torch.load(
            './log/checkpoint_epoch{}.pth'.format(args.start_epoch - 1))
        model.load_state_dict(checkpoint['state_dict'])

    data_loaders, data_size = get_dataloader(
        args.train_root_dir, args.valid_root_dir, args.train_csv_name,
        args.valid_csv_name, args.num_train_triplets, args.num_valid_triplets,
        args.batch_size, args.num_workers)

    for epoch in range(args.start_epoch, args.num_epochs + args.start_epoch):

        print(80 * '=')
        print(datetime.datetime.now().time())
        print('Epoch [{}/{}]'.format(epoch,
                                     args.num_epochs + args.start_epoch - 1))

        if ((epoch + 1) % 10 == 0):
            data_loaders['train'].dataset.advance_to_the_next_subset()
            data_loaders['valid'].dataset.advance_to_the_next_subset()

        train_valid(model, optimizer, scheduler, epoch, data_loaders,
                    data_size)

    print(80 * '=')
Example #9
    def __init__(self, model_path, gpu_id=0):
        from models import build_model
        from data_loader import get_dataloader
        from post_processing import get_post_processing
        from utils import get_metric
        self.gpu_id = gpu_id
        if self.gpu_id is not None and isinstance(
                self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
            torch.backends.cudnn.benchmark = True
        else:
            self.device = torch.device("cpu")
        print('load model:', model_path)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        config = checkpoint['config']
        config['arch']['backbone']['pretrained'] = False

        self.validate_loader = get_dataloader(config['dataset']['validate'],
                                              config['distributed'])

        self.model = build_model(config['arch'].pop('type'), **config['arch'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)

        self.post_process = get_post_processing(config['post_processing'])
        self.metric_cls = get_metric(config['metric'])
Example #10
def main():

    model = FaceNetModel(embedding_size=args.embedding_size,
                         num_classes=args.num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

    if args.start_epoch != 0:
        checkpoint = torch.load(
            './log/checkpoint_epoch{}.pth'.format(args.start_epoch - 1))
        model.load_state_dict(checkpoint['state_dict'])

    for epoch in range(args.start_epoch, args.num_epochs + args.start_epoch):

        print(80 * '=')
        print('Epoch [{}/{}]'.format(epoch,
                                     args.num_epochs + args.start_epoch - 1))

        data_loaders, data_size = get_dataloader(
            args.train_root_dir, args.valid_root_dir, args.train_csv_name,
            args.valid_csv_name, args.num_train_triplets,
            args.num_valid_triplets, args.batch_size, args.num_workers)

        print("load:", data_size)
        for phase in ['train', 'valid']:
            print(phase, len(data_loaders[phase]))

        train_valid(model, optimizer, scheduler, epoch, data_loaders,
                    data_size)

    print(80 * '=')
Example #11
def generate_test():

    output_dir = "../data/PCL/results"
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = torch.load("./exp/22_DANet.pth").module

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)
    model.eval()

    labels = [100, 200, 300, 400, 500, 600, 700, 800]

    test_loader = get_dataloader(img_dir="../data/PCL/image_A", mask_dir="../data/PCL/image_A", mode="test",
                                  batch_size=64, num_workers=8)

    for image, _, name in test_loader:
        image = image.to(device, dtype=torch.float32)
        output = model(image)
        pred = torch.softmax(output, dim=1).cpu().detach().numpy()
        pred = semantic_to_mask(pred, labels=labels).squeeze().astype(np.uint16)
        for i in range(pred.shape[0]):
            cv2.imwrite(os.path.join(output_dir, name[i].split('.')[0]) + ".png", pred[i, :, :])
            print(name[i])
Example #12
def main():

    # tokenizer
    with open(hparams.dataset_path + "/vocab.pkl", mode='rb') as io:
        vocab = pickle.load(io)
    pad_sequence = PadSequence(length=hparams.max_len,
                               pad_val=vocab.to_indices(vocab.padding_token))
    tokenizer = Tokenizer(vocab=vocab,
                          split_fn=split_sentence,
                          pad_fn=pad_sequence)

    # data loader
    train_loader, valid_loader, test_loader = data_loader.get_dataloader(
        hparams)
    runner = Runner(hparams, vocab=tokenizer.vocab)

    print('Training on ' + str(hparams.device))
    for epoch in range(hparams.num_epochs):
        train_loss, train_acc = runner.run(train_loader, 'train')
        valid_loss, valid_acc = runner.run(valid_loader, 'eval')

        print(
            "[Epoch %d/%d] [Train Loss: %.4f] [Train Acc: %.4f] [Valid Loss: %.4f] [Valid Acc: %.4f]"
            % (epoch + 1, hparams.num_epochs, train_loss, train_acc,
               valid_loss, valid_acc))

        if runner.early_stop(valid_loss, epoch + 1):
            break

    test_loss, test_acc = runner.run(test_loader, 'eval')
    print("Training Finished")
    print("Test Accuracy: %.2f%%" % (100 * test_acc))
Example #13
def main(config):
    train_loader, eval_loader = get_dataloader(config['data_loader']['type'], config['data_loader']['args'])
    if os.path.isfile(config['data_loader']['args']['alphabet']):
        config['data_loader']['args']['alphabet'] = str(np.load(config['data_loader']['args']['alphabet']))

    prediction_type = config['arch']['args']['prediction']['type']
    # label converter setup
    if prediction_type == 'CTC':
        converter = CTCLabelConverter(config['data_loader']['args']['alphabet'])
    else:
        converter = AttnLabelConverter(config['data_loader']['args']['alphabet'])
    num_class = len(converter.character)

    # loss setup
    if prediction_type == 'CTC':
        criterion = CTCLoss(zero_infinity=True).cuda()
    else:
        criterion = CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0

    model = get_model(num_class, config)

    config['name'] = config['name'] + '_' + model.name
    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      val_loader=eval_loader,
                      converter=converter,
                      weights_init=weights_init)
    trainer.train()
Example #14
File: eval.py Project: gaoshangle/dbnet
    def __init__(self, model_path, gpu_id=0):
        from models import build_model
        from data_loader import get_dataloader
        from post_processing import get_post_processing
        from utils import get_metric
        self.gpu_id = gpu_id
        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
            torch.backends.cudnn.benchmark = True
        else:
            self.device = torch.device("cpu")
        # print(self.gpu_id) 0
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        config = checkpoint['config']
        config['arch']['backbone']['pretrained'] = False
        config['dataset']['train']['dataset']['args']['data_path'][0] = '/home/share/gaoluoluo/dbnet/datasets/train_zhen.txt'
        config['dataset']['validate']['dataset']['args']['data_path'][0] = '/home/share/gaoluoluo/dbnet/datasets/test_zhen.txt'

        print("config:",config)
        self.validate_loader = get_dataloader(config['dataset']['validate'], config['distributed'])

        self.model = build_model(config['arch'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)

        self.post_process = get_post_processing(config['post_processing'])
        self.metric_cls = get_metric(config['metric'])
Example #15
    def __init__(self, model_path, gpu_id=0):
        from models import get_model
        from data_loader import get_dataloader
        #from utils import  get_metric
        self.model_path = model_path

        self.device = torch.device("cuda:%s" % gpu_id)
        if gpu_id is not None:
            torch.backends.cudnn.benchmark = True
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        # print(checkpoint['state_dict'])
        config = checkpoint['config']
        config['distributed'] = False
        config['arch']['args']['pretrained'] = False
        config['arch']['args']['training'] = False
        self.distributed = config['distributed']
        self.validate_loader = get_dataloader(config['dataset']['validate'],
                                              self.distributed)

        self.model = get_model(config['arch'])
        self.model = nn.DataParallel(self.model)
        # print(self.model)
        # exit()
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)
Example #16
def main(config):
    import torch
    from models import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from post_processing import get_post_processing
    from utils import get_metric
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=torch.cuda.device_count(),
            rank=args.local_rank)
        config['distributed'] = True
    else:
        config['distributed'] = False
    config['local_rank'] = args.local_rank

    train_loader = get_dataloader(config['dataset']['train'],
                                  config['distributed'])
    assert train_loader is not None
    if 'validate' in config['dataset']:
        validate_loader = get_dataloader(config['dataset']['validate'], False)
    else:
        validate_loader = None

    criterion = build_loss(config['loss'].pop('type'), **config['loss']).cuda()

    config['arch']['backbone']['in_channels'] = 3 if config['dataset'][
        'train']['dataset']['args']['img_mode'] != 'GRAY' else 1
    config['arch']['backbone']['pretrained'] = False
    model = build_model(config['arch']['type'], **config['arch'])

    post_p = get_post_processing(config['post_processing'])
    metric = get_metric(config['metric'])

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      post_process=post_p,
                      metric_cls=metric,
                      validate_loader=validate_loader)
    trainer.train()
Example #17
    def __init__(self, _hparams):
        self.hparams = _hparams
        set_seed(_hparams.fixed_seed)
        self.train_loader = get_dataloader(_hparams.train_src_path, _hparams.train_dst_path,
                                           _hparams.batch_size, _hparams.num_workers)
        self.src_vocab, self.dst_vocab = load_vocab(_hparams.train_src_pkl, _hparams.train_dst_pkl)
        self.device = torch.device(_hparams.device)
        self.model = NMT(_hparams.embed_size, _hparams.hidden_size,
                         self.src_vocab, self.dst_vocab, self.device,
                         _hparams.dropout_rate).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=_hparams.lr)
Example #18
def main(config):
    train_loader = get_dataloader(config['data_loader']['type'], config['data_loader']['args'])

    criterion = get_loss(config).cuda()

    model = get_model(config)

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader)
    trainer.train()
Example #19
    def __init__(self):
        parser = argparse.ArgumentParser(description='Image Captioning')
        parser.add_argument('--root',
                            default='../../../cocodataset/',
                            type=str)
        parser.add_argument('--crop_size', default=224, type=int)
        parser.add_argument('--epochs', default=100, type=int)
        parser.add_argument('--lr', default=1e-4, type=float)
        parser.add_argument('--batch_size', default=128, help='')
        parser.add_argument('--num_workers', default=4, type=int)
        parser.add_argument('--embed_dim', default=256, type=int)
        parser.add_argument('--hidden_size', default=512, type=int)
        parser.add_argument('--num_layers', default=1, type=int)
        parser.add_argument('--model_path', default='./model/', type=str)
        parser.add_argument('--vocab_path', default='./vocab/', type=str)
        parser.add_argument('--save_step', default=1000, type=int)

        self.args = parser.parse_args()
        self.Multi_GPU = False

        # if torch.cuda.device_count() > 1:
        #     print('Multi GPU Activate!')
        #     print('Using GPU :', int(torch.cuda.device_count()))
        #     self.Multi_GPU = True

        os.makedirs(self.args.model_path, exist_ok=True)

        transform = transforms.Compose([
            transforms.RandomCrop(self.args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        with open(self.args.vocab_path + 'vocab.pickle', 'rb') as f:
            data = pickle.load(f)

        self.vocab = data

        self.DataLoader = get_dataloader(root=self.args.root,
                                         transform=transform,
                                         shuffle=True,
                                         batch_size=self.args.batch_size,
                                         num_workers=self.args.num_workers,
                                         vocab=self.vocab)

        self.Encoder = Encoder(embed_dim=self.args.embed_dim)
        self.Decoder = Decoder(embed_dim=self.args.embed_dim,
                               hidden_size=self.args.hidden_size,
                               vocab_size=len(self.vocab),
                               num_layers=self.args.num_layers)
Example #20
    def __init__(self):
        parser = argparse.ArgumentParser(description='Image Captioning')
        parser.add_argument('--root',
                            default='../../../cocodataset/',
                            type=str)
        parser.add_argument(
            '--sample_image',
            default='../../../cocodataset/val2017/000000435205.jpg',
            type=str)
        parser.add_argument('--epochs', default=100, type=int)
        parser.add_argument('--lr', default=1e-4, type=float)
        parser.add_argument('--batch_size', default=128, help='')
        parser.add_argument('--num_workers', default=4, type=int)
        parser.add_argument('--embed_dim', default=256, type=int)
        parser.add_argument('--hidden_size', default=512, type=int)
        parser.add_argument('--num_layers', default=1, type=int)
        parser.add_argument('--encoder_path',
                            default='./model/Encoder-100.ckpt',
                            type=str)
        parser.add_argument('--decoder_path',
                            default='./model/Decoder-100.ckpt',
                            type=str)
        parser.add_argument('--vocab_path', default='./vocab/', type=str)

        self.args = parser.parse_args()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, )),
            transforms.Resize((224, 224))
        ])

        with open(self.args.vocab_path + 'vocab.pickle', 'rb') as f:
            data = pickle.load(f)

        self.vocab = data

        self.DataLoader = get_dataloader(root=self.args.root,
                                         transform=self.transform,
                                         shuffle=True,
                                         batch_size=self.args.batch_size,
                                         num_workers=self.args.num_workers,
                                         vocab=self.vocab)

        self.Encoder = Encoder(embed_dim=self.args.embed_dim)
        self.Decoder = Decoder(embed_dim=self.args.embed_dim,
                               hidden_size=self.args.hidden_size,
                               vocab_size=len(self.vocab),
                               num_layers=self.args.num_layers)
Example #21
    def __init__(self,
                 model,
                 num_workers,
                 batch_size,
                 num_epochs,
                 model_save_path,
                 model_save_name,
                 fold,
                 training_history_path,
                 task="cls"):
        self.model = model
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.phases = ["train", "valid"]
        self.model_save_path = model_save_path
        self.model_save_name = model_save_name
        self.fold = fold
        self.training_history_path = training_history_path
        self.criterion = BCEWithLogitsLoss()

        self.optimizer = SGD(self.model.parameters(),
                             lr=1e-02,
                             momentum=0.9,
                             weight_decay=1e-04)
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           mode='max',
                                           factor=0.1,
                                           patience=10,
                                           verbose=True,
                                           threshold=1e-8,
                                           min_lr=1e-05,
                                           eps=1e-8)
        self.model = self.model.cuda()
        self.dataloaders = {
            phase: get_dataloader(phase=phase,
                                  fold=fold,
                                  train_batch_size=self.batch_size,
                                  valid_batch_size=self.batch_size,
                                  num_workers=self.num_workers,
                                  task=task)
            for phase in self.phases
        }
        self.loss = {phase: [] for phase in self.phases}
        self.accuracy = {
            phase: np.zeros(shape=(0, 4), dtype=np.float32)
            for phase in self.phases
        }
Example #22
def main(config):
    train_loader, eval_loader = get_dataloader(config['data_loader']['type'],
                                               config['data_loader']['args'])

    converter = strLabelConverter(config['data_loader']['args']['alphabet'])
    criterion = CTCLoss(zero_infinity=True)

    model = get_model(config)

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      val_loader=eval_loader,
                      converter=converter)
    trainer.train()
Example #23
    def __init__(self, model_path, gpu_id=0):
        from models import get_model
        from data_loader import get_dataloader
        from post_processing import get_post_processing
        from utils import get_metric
        self.device = torch.device("cuda:%s" % gpu_id)
        if gpu_id is not None:
            torch.backends.cudnn.benchmark = True
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        config = checkpoint['config']
        config['arch']['args']['pretrained'] = False

        self.validate_loader = get_dataloader(config['dataset']['validate'], config['distributed'])

        self.model = get_model(config['arch'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)

        self.post_process = get_post_processing(config['post_processing'])
        self.metric_cls = get_metric(config['metric'])
Example #24
    def __init__(self, model, num_workers, batch_size, num_epochs,
                 model_save_path, model_save_name, fold,
                 training_history_path):
        self.model = model
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.phases = ["train", "valid"]
        self.model_save_path = model_save_path
        self.model_save_name = model_save_name
        self.fold = fold
        self.training_history_path = training_history_path
        self.criterion = DiceBCELoss()

        self.optimizer = SGD(self.model.parameters(),
                             lr=1e-02,
                             momentum=0.9,
                             weight_decay=1e-04)
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           mode='max',
                                           factor=0.1,
                                           patience=10,
                                           verbose=True,
                                           threshold=1e-8,
                                           min_lr=1e-05,
                                           eps=1e-8)
        self.model = self.model.cuda()
        self.dataloaders = {
            phase: get_dataloader(
                phase=phase,
                fold=fold,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
            )
            for phase in self.phases
        }
        self.loss = {phase: [] for phase in self.phases}
        self.bce_loss = {phase: [] for phase in self.phases}
        self.dice_loss = {phase: [] for phase in self.phases}
        self.dice = {phase: [] for phase in self.phases}
Example #25
def main():

    train_data, val_data = get_dataloader(data_list=config.data_list,
                                          root_dir=config.root_dir,
                                          bs=config.batchsize,
                                          sz=config.size)
    val_criterion = get_loss()
    criterion = get_loss(training=True)
    snapmix_criterion = get_loss(tag='snapmix')
    model = get_model(model_name='efficientnet-b3')
    # model = get_model(model_name='tf_efficientnet_b2_ns')
    # model = get_model(model_name='resnext50_32x4d')
    trainer = Trainer(
        config,
        model,
        criterion,
        val_criterion,
        snapmix_criterion,
        train_data,
        val_data,
    )
    trainer.train()
Example #26
def main(config):
    if os.path.isfile(config['data_loader']['args']['dataset']['alphabet']):
        config['data_loader']['args']['dataset']['alphabet'] = str(
            np.load(config['data_loader']['args']['dataset']['alphabet']))

    prediction_type = config['arch']['args']['prediction']['type']
    num_class = len(config['data_loader']['args']['dataset']['alphabet'])

    # loss setup
    if prediction_type == 'CTC':
        criterion = CTCLoss()
    else:
        raise NotImplementedError

    ctx = try_gpu(config['trainer']['gpus'])
    model = get_model(num_class, config['arch']['args'])
    model.hybridize()
    model.initialize(ctx=ctx)

    img_w = config['data_loader']['args']['dataset']['img_w']
    img_h = config['data_loader']['args']['dataset']['img_h']
    train_loader, val_loader = get_dataloader(
        config['data_loader']['type'],
        config['data_loader']['args'],
        num_label=model.get_batch_max_length(img_h=img_h, img_w=img_w,
                                             ctx=ctx))

    config['lr_scheduler']['args']['step'] *= len(train_loader)
    config['name'] = config['name'] + '_' + model.model_name
    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      ctx=ctx)
    trainer.train()
Example #27
    mod = 'custom'
    batch_size = 64
    root_dir = '/home/delvinso/nephro/'
    manifest_path = os.path.join(root_dir, 'all_data', 'all_kidney_manifest.csv')
    out_dir = os.path.join(root_dir, 'output')

    model = Net(task=task, mod=mod)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('Device: {}'.format(device))

    checkpoint_path = '/home/delvinso/nephro/output/bladder/custom_no_wts/_best.path.tar'
    model.eval().to(device)

    print('Loading Checkpoint Path: {}....'.format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    # initialize dataloader to retrieve all the images in the manifest (which was created by concatenating them together using R)
    print('Retrieving dataloader...')
    dls = get_dataloader(sets=sets,
                         root_dir=root_dir,
                         task=task,
                         manifest_path=manifest_path,
                         batch_size=batch_size,
                         return_pid = True)
    print('Making predictions...')
    df = pred_bladder(dataloaders=dls, sets=sets)

    out_path = os.path.join(out_dir, 'bladder_probs2.csv')
    df.to_csv(out_path)
    print('Done, results saved to {}'.format(out_path))
Example #28

def dice_single_channel(probability, truth, threshold):
    p = (probability.view(-1) > threshold).float()
    t = (truth.view(-1) > 0.5).float()
    if p.sum() == 0.0 and t.sum() == 0.0:
        dice = 1.0
    else:
        # parenthesise the whole expression before .item() so dice is a plain float,
        # matching the 1.0 returned by the empty-mask branch above
        dice = ((2.0 * (p * t).sum()) / (p.sum() + t.sum())).item()

    return dice


if __name__ == '__main__':
    from data_loader import get_dataloader
    from torch.nn import BCEWithLogitsLoss
    from model import UResNet34

    dataloader = get_dataloader(phase="train", fold=0, batch_size=4, num_workers=2)
    model = UResNet34()
    model.cuda()
    model.train()
    imgs, masks = next(iter(dataloader))
    imgs, masks = next(iter(dataloader))
    preds = model(imgs.cuda())
    criterion = BCEWithLogitsLoss()
    loss = criterion(preds, masks.cuda())
    print(loss.item())
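A quick sanity check of dice_single_channel above (illustrative only; the tensors are hand-picked values, not data from the project):

import torch

prob = torch.tensor([0.9, 0.2, 0.8, 0.1])
truth = torch.tensor([1.0, 0.0, 1.0, 0.0])
# the thresholded prediction matches the ground truth exactly, so the dice score is 1.0
print(dice_single_channel(prob, truth, threshold=0.5))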
Example #29
def train_val(config):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=4,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "BASNet":
        model = BASNet(n_classes=8)
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
    elif config.model_type == "Deeplabv3+":
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=8,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    if config.iscontinue:
        model = torch.load("./exp/24_Deeplabv3+_0.7825757691389714.pth").module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    labels = [100, 200, 300, 400, 500, 600, 700, 800]
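    # the class names below are in Chinese: water, transport structures, buildings,
    # cropland, grassland, forest, bare soil, other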
    objects = ['水体', '交通建筑', '建筑', '耕地', '草地', '林地', '裸土', '其他']

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)

    criterion = BasLoss()

    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    frequency = np.array(
        [0.1051, 0.0607, 0.1842, 0.1715, 0.0869, 0.1572, 0.0512, 0.1832])
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        cm = np.zeros([8, 8])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                mask = mask.to(device, dtype=torch.float16)

                pred = model(image)
                loss = criterion(pred, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break

            # scheduler.step()
            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
            torch.cuda.empty_cache()

        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    pred, _, _, _, _, _, _, _ = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    if locker == 25:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                    # break
                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                scheduler.step()

                if fw_miou > max_fwiou:
                    if torch.__version__ == "1.6.0":
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('mIoU/val', miou.mean(), epoch + 1)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
                torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
Example #30
def train_val(config):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "DANet":
        # src = "./pretrained/60_DANet_0.8086.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # pretrained_dict.pop('seg1.1.weight')
        # pretrained_dict.pop('seg1.1.bias')
        model = DANet(backbone='resnext101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "Deeplabv3+":
        # src = "./pretrained/Deeplabv3+.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # # print(pretrained_dict.keys())
        # for key in list(pretrained_dict.keys()):
        #     if key.split('.')[0] == "cbr_last":
        #         pretrained_dict.pop(key)
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=config.output_ch,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    if config.iscontinue:
        model = torch.load("./exp/13_Deeplabv3+_0.7619.pth",
                           map_location='cpu').module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    labels = [1, 2, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
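    # the class names below are in Chinese: water, roads, buildings, airport, parking lot,
    # playground, ordinary cropland, greenhouses, natural grassland, green space,
    # natural forest, planted forest, natural bare soil, man-made bare soil, other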
    objects = [
        '水体', '道路', '建筑物', '机场', '停车场', '操场', '普通耕地', '农业大棚', '自然草地', '绿地绿化',
        '自然林', '人工林', '自然裸土', '人为裸土', '其它'
    ]
    frequency = np.array([
        0.0279, 0.0797, 0.1241, 0.00001, 0.0616, 0.0029, 0.2298, 0.0107,
        0.1207, 0.0249, 0.1470, 0.0777, 0.0617, 0.0118, 0.0187
    ])

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)

    if config.smooth == "all":
        criterion = LabelSmoothSoftmaxCE()
    elif config.smooth == "edge":
        criterion = LabelSmoothCE()
    else:
        criterion = nn.CrossEntropyLoss()

    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        seed = np.random.randint(0, 2, 1)
        seed = 0  # hard-coded override: the multi-scale branch (seed == 1) below never runs
        print("seed is ", seed)
        if seed == 1:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size // 2,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size // 2,
                                        num_workers=config.num_workers)
        else:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size,
                                        num_workers=config.num_workers)

        cm = np.zeros([15, 15])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)

                if seed == 0:
                    pass
                elif seed == 1:
                    image = F.interpolate(image,
                                          size=(384, 384),
                                          mode='bilinear',
                                          align_corners=True)
                    mask = F.interpolate(mask.float(),
                                         size=(384, 384),
                                         mode='nearest')

                if config.smooth == "edge":
                    mask = mask.to(device, dtype=torch.float32)
                else:
                    mask = mask.to(device, dtype=torch.long).argmax(dim=1)

                aux_out, out = model(image)
                aux_loss = criterion(aux_out, mask)
                seg_loss = criterion(out, mask)
                loss = aux_loss + seg_loss

                # pred = model(image)
                # loss = criterion(pred, mask)

                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break

            # scheduler.step()
            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
        torch.cuda.empty_cache()
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    _, pred = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    if locker == 5:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                    # break
                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                scheduler.step()

                if True:  # save every epoch; the best-FWIoU check is effectively disabled here
                    if torch.__version__ == "1.6.0":
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
            torch.cuda.empty_cache()
    writer.close()
    print("Training finished")