Exemplo n.º 1
0
def main():
    """Parse CLI options, assemble the frozen config, and run evaluation."""
    cli = argparse.ArgumentParser(
        description="eval models from checkpoints.")

    cli.add_argument(
        "-cfg",
        "--config_file",
        type=str,
        default="configs/h**o/default.yaml",
        metavar="FILE",
        help="Path to config file",
    )
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    args = cli.parse_args()

    # YAML file first, then command-line overrides; freeze before use.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    run_eval(cfg)
Exemplo n.º 2
0
def main():
    """Evaluate a pretrained detection model on the COCO or VOC eval split.

    Loads the config, builds the model, runs prediction through a
    multithreaded iterator, and dispatches to the dataset-specific
    evaluation routine.
    """
    args = parse_args()
    cfg.merge_from_file(args.config)
    # Bug fix: `cfg.freeze` only referenced the bound method without
    # calling it, so the config was never actually frozen.
    cfg.freeze()

    model = setup_model(cfg)
    load_pretrained_model(cfg, args.config, model, args.pretrained_model)

    dataset = setup_dataset(cfg, 'eval')
    iterator = iterators.MultithreadIterator(dataset,
                                             args.batchsize,
                                             repeat=False,
                                             shuffle=False)

    model.use_preset('evaluate')
    if args.gpu >= 0:
        model.to_gpu(args.gpu)

    in_values, out_values, rest_values = apply_to_iterator(model.predict,
                                                           iterator,
                                                           hook=ProgressHook(
                                                               len(dataset)))
    # delete unused iterators explicitly
    del in_values

    if cfg.dataset.eval == 'COCO':
        eval_coco(out_values, rest_values)
    elif cfg.dataset.eval == 'VOC':
        eval_voc(out_values, rest_values)
    else:
        raise ValueError(
            'Unsupported eval dataset: {}'.format(cfg.dataset.eval))
Exemplo n.º 3
0
def main():
    """Train Struct2Vec on an edge-list graph, then evaluate and plot the
    learned node embeddings."""
    cli = argparse.ArgumentParser("Struct2vec training")
    cli.add_argument(
        "--config-file",
        type=str,
        default="",
        metavar="FILE",
        help="path to config file",
    )
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    args = cli.parse_args()

    # YAML file first, CLI overrides second; freeze to lock the config.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger, log_path = create_logger(cfg)
    logger.info(cfg)

    graph = networkx.read_edgelist(cfg.DATA.GRAPH_PATH,
                                   create_using=networkx.DiGraph(),
                                   nodetype=None,
                                   data=[('weight', int)])
    model = Struct2Vec(graph, cfg, logger)

    model.train()
    embedding = model.get_embedding()

    eval_embedding(embedding, cfg.DATA.LABEL_PATH, logger)
    vis_embedding(embedding, cfg.DATA.LABEL_PATH, log_path)
Exemplo n.º 4
0
def main():
    """Parse command-line options, assemble the config, and launch training."""
    cli = argparse.ArgumentParser(
        description="easy2train for Change Detection")

    cli.add_argument(
        "-cfg",
        "--config_file",
        type=str,
        default="configs/h**o/default.yaml",
        metavar="FILE",
        help="Path to config file",
    )
    cli.add_argument(
        "-se",
        "--skip_eval",
        action="store_true",
        help="Do not eval the models(checkpoints)",
    )
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    args = cli.parse_args()

    # YAML file first, CLI overrides second; freeze before handing off.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    run_train(cfg)
Exemplo n.º 5
0
def main():
    """End-to-end RecLib run: configure, seed, train, test, and profile."""
    cli = argparse.ArgumentParser(description='PyTorch RecLib')
    cli.add_argument('--config-file',
                     type=str,
                     default='Configs/default.yaml',
                     metavar='FILE',
                     help='path to configuration file')
    args = cli.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    # # create output dir
    # experiment_dir = get_unique_temp_folder(cfg.OUTPUT_DIR)

    print("Collecting env info (may take some time)\n")
    print(get_pretty_env_info())
    print("Loading configuration file from {}".format(args.config_file))
    print('Running with configuration: \n')
    print(cfg)

    # set random seed for pytorch and numpy (SEED == 0 means "random")
    if cfg.SEED == 0:
        print("Using random seed")
        torch.backends.cudnn.benchmark = True
    else:
        print("Using manual seed: {}".format(cfg.SEED))
        torch.manual_seed(cfg.SEED)
        torch.cuda.manual_seed(cfg.SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(cfg.SEED)

    # create dataloaders; only the train split carries field metadata
    train_loader, field_info = make_dataloader(cfg, split='train')
    valid_loader, _ = make_dataloader(cfg, split='valid')
    test_loader, _ = make_dataloader(cfg, split='test')

    # create model
    model = get_model(cfg, field_info)

    best_model = train(cfg, model, train_loader, valid_loader, save=False)
    auc, log_loss = test(cfg, best_model, test_loader, device=cfg.DEVICE)
    banner = "*" * 20
    print(banner)
    print("* Test AUC: {:.5f} *".format(auc))
    print("* Test Log Loss: {:.5f} *".format(log_loss))
    print(banner)

    model.eval()
    macs, params = profile_model(model, test_loader, device=cfg.DEVICE)
    print(banner)
    print("* MACs (M): {} *".format(macs / 10**6))
    print("* #Params (M): {} *".format(params / 10**6))
    # Assumes 8 bytes per parameter (float64) — TODO confirm model dtype,
    # since torch tensors default to float32.
    print("* Model Size (MB): {} *".format(params * 8 /
                                           10**6))
    print(banner)
Exemplo n.º 6
0
def main():
    """Evaluate a saved NDDR checkpoint on the multi-task test set."""
    cli = argparse.ArgumentParser(description="PyTorch NDDR Training")
    cli.add_argument(
        "--config-file",
        type=str,
        default="configs/vgg16_nddr_pret.yaml",
        metavar="FILE",
        help="path to config file",
    )
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    args = cli.parse_args()

    cfg.merge_from_file(args.config_file)
    # Experiment name = config file basename without its ".yaml" suffix.
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # load the data
    test_dataset = MultiTaskDataset(
        data_dir=cfg.DATA_DIR,
        data_list_1=cfg.TEST.DATA_LIST_1,
        data_list_2=cfg.TEST.DATA_LIST_2,
        output_size=cfg.TEST.OUTPUT_SIZE,
        random_scale=cfg.TEST.RANDOM_SCALE,
        random_mirror=cfg.TEST.RANDOM_MIRROR,
        random_crop=cfg.TEST.RANDOM_CROP,
        ignore_label=cfg.IGNORE_LABEL,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=cfg.TEST.BATCH_SIZE,
                                              shuffle=False)

    # Two single-task backbones joined by NDDR layers; weights come from
    # the checkpoint, so the backbones start empty.
    net1 = DeepLabLargeFOVBN(3, cfg.MODEL.NET1_CLASSES, weights='')
    net2 = DeepLabLargeFOVBN(3, cfg.MODEL.NET2_CLASSES, weights='')
    model = NDDRNet(net1, net2,
                    shortcut=cfg.MODEL.SHORTCUT,
                    bn_before_relu=cfg.MODEL.BN_BEFORE_RELU)
    ckpt_path = os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                             'ckpt-%s.pth' % str(cfg.TEST.CKPT_ID).zfill(5))
    print("Evaluating Checkpoint at %s" % ckpt_path)
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    if cfg.CUDA:
        model = model.cuda()

    mIoU, pixel_acc, angle_metrics = evaluate(test_loader, model)
    print('Mean IoU: {:.3f}'.format(mIoU))
    print('Pixel Acc: {:.3f}'.format(pixel_acc))
    for name, value in angle_metrics.items():
        print('{}: {:.3f}'.format(name, value))
def _parse_args():
    """Parse training CLI flags and load the config file.

    Returns:
        (cfg, args): the frozen global config and the parsed namespace.

    Raises:
        ValueError: if the config-file path is empty.
    """
    parser = argparse.ArgumentParser(description='Training Config',
                                     add_help=False)
    parser.add_argument('--config_file',
                        metavar='FILE',
                        type=str,
                        required=True,
                        default='',
                        help='path to config file')
    parser.add_argument('--initial_checkpoint',
                        metavar='FILE',
                        type=str,
                        default=None,
                        help='Initialize model from this checkpoint (default: none)')
    parser.add_argument('--resume',
                        metavar='FILE',
                        type=str,
                        default=None,
                        help='Resume full model and optimizer state from checkpoint (default: none)')
    parser.add_argument('--no_resume_opt',
                        action='store_true',
                        default=False,
                        help='prevent resume of optimizer state when resuming model')
    parser.add_argument('--start_epoch',
                        metavar='N',
                        type=int,
                        default=None,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--seed',
                        metavar='S',
                        type=int,
                        default=1024,
                        help='random seed (default: 1024)')
    parser.add_argument('--gpu',
                        type=str,
                        default='',
                        help='gpu list to use, for example 1,2,3')

    args = parser.parse_args()

    # Defensive: argparse already enforces required=True, but guard anyway.
    if not args.config_file:
        raise ValueError('Please input config file path!')
    cfg.merge_from_file(args.config_file)
    cfg.freeze()
    return cfg, args
Exemplo n.º 8
0
def main():
    """Build the data pipeline, load test weights, and run inference."""
    # GPU count from the environment; default to a single GPU.
    num_gpus = int(os.environ.get("GPU_NUM", 1))

    cfg.merge_from_file('configs/inference.yml')
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        mkdir(output_dir)

    # Fixed input sizes -> let cuDNN pick the fastest kernels.
    cudnn.benchmark = True

    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
    model = build_model(cfg, num_classes)
    model.load_state_dict(torch.load(cfg.TEST.WEIGHT))

    inference(cfg, model, val_loader, num_query, num_gpus)
Exemplo n.º 9
0
def main(device):
    """Train SDNE node embeddings, then evaluate and visualize them.

    Args:
        device: torch device the model and matrices are placed on.
    """
    cli = argparse.ArgumentParser("SDNE training")
    cli.add_argument(
        "--config-file",
        type=str,
        default="",
        metavar="FILE",
        help="path to config file",
    )
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    args = cli.parse_args()

    # Config file first, CLI overrides second, then lock the config.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger, log_path = create_logger(cfg)
    logger.info(cfg)

    graph = networkx.read_edgelist(cfg.DATA.GRAPH_PATH,
                                   create_using=networkx.DiGraph(),
                                   nodetype=None,
                                   data=[('weight', int)])
    model = SDNE(graph, cfg).to(device)

    # Laplacian and adjacency matrices feed the mixed-order loss.
    L_matrix, adj_matrix = model.get_matrix()
    criterion = MixLoss(cfg, L_matrix, adj_matrix, device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.WORD2VEC.LR,
                                 weight_decay=cfg.SDNE.L2)

    train(adj_matrix, model, cfg, criterion, optimizer, device)
    embedding = model.get_embedding(adj_matrix.to(device))

    eval_embedding(embedding, cfg.DATA.LABEL_PATH, logger)
    vis_embedding(embedding, cfg.DATA.LABEL_PATH, log_path)
Exemplo n.º 10
0
def main():
    """Run a detector over a webcam, directory, or dataset and show results.

    Displays each visualized frame in an OpenCV window; pressing 'q'
    stops the loop.
    """
    args = parse_args()
    cfg.merge_from_file(args.config)
    cfg.freeze()

    model = setup_model(cfg)
    load_pretrained_model(cfg, args.config, model, args.pretrained_model)

    model.use_preset(args.use_preset)
    if args.gpu >= 0:
        model.to_gpu(args.gpu)

    # Choose the input source. Webcam uses a 1 ms waitKey so frames keep
    # flowing; the other sources block until a key is pressed (wait=0).
    if args.webcam:
        data_iter = WebCamIter()
        data_iter.start_device()
        wait = 1
    elif args.indir is not None:
        data_iter = DirectoryIter(args.indir)
        wait = 0
    else:
        data_iter = setup_dataset(cfg, args.split)
        wait = 0
    visualizer = Visualizer(cfg.dataset.eval)

    for data in data_iter:
        # Bug fix: use isinstance instead of `type(data) == tuple` so tuple
        # subclasses (e.g. namedtuple batches) are unpacked too.
        if isinstance(data, tuple):
            img = data[0]
        else:
            img = data
        # Keep at most the top-10 predictions per output head.
        output = [[v[0][:10]] for v in model.predict([img.copy()])]
        result = visualizer.visualize(img, output)

        cv2.imshow('result', result)
        key = cv2.waitKey(wait) & 0xff
        if key == ord('q'):
            break
    cv2.destroyAllWindows()
    if args.webcam:
        data_iter.stop_device()
Exemplo n.º 11
0
def main():
    """Distributed (ChainerMN) evaluation of a detection model.

    Rank 0 drives the iterator and aggregates predictions; every other
    rank only serves `model.predict` through `apply_to_iterator`.
    """
    args = parse_args()
    cfg.merge_from_file(args.config)
    # Bug fix: `cfg.freeze` was referenced without being called, so the
    # config was never actually frozen.
    cfg.freeze()

    comm = chainermn.create_communicator('pure_nccl')
    device = comm.intra_rank

    model = setup_model(cfg)
    load_pretrained_model(cfg, args.config, model, args.pretrained_model)
    dataset = setup_dataset(cfg, 'eval')

    model.use_preset('evaluate')
    chainer.cuda.get_device_from_id(device).use()
    model.to_gpu()

    # Worker ranks: participate in prediction only, then return.
    if not comm.rank == 0:
        apply_to_iterator(model.predict, None, comm=comm)
        return

    # Batch size is scaled by the communicator size so each rank gets a
    # full per-rank batch.
    iterator = iterators.MultithreadIterator(dataset,
                                             args.batchsize * comm.size,
                                             repeat=False,
                                             shuffle=False)

    in_values, out_values, rest_values = apply_to_iterator(model.predict,
                                                           iterator,
                                                           hook=ProgressHook(
                                                               len(dataset)),
                                                           comm=comm)
    # delete unused iterators explicitly
    del in_values

    if cfg.dataset.eval == 'COCO':
        eval_coco(out_values, rest_values)
    elif cfg.dataset.eval == 'VOC':
        eval_voc(out_values, rest_values)
    else:
        raise ValueError(
            'Unsupported eval dataset: {}'.format(cfg.dataset.eval))
Exemplo n.º 12
0
def main(device):
    """Train LINE embeddings on an edge-list graph, then evaluate and
    visualize them.

    Args:
        device: torch device used for the model and training loop.
    """
    cli = argparse.ArgumentParser("LINE training")
    cli.add_argument(
        "--config-file",
        type=str,
        default="",
        metavar="FILE",
        help="path to config file",
    )
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    args = cli.parse_args()

    # Config file first, CLI overrides second, then lock it down.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger, log_path = create_logger(cfg)
    logger.info(cfg)

    graph = networkx.read_edgelist(cfg.DATA.GRAPH_PATH,
                                   create_using=networkx.DiGraph(),
                                   nodetype=None,
                                   data=[('weight', int)])
    model = LINE(graph, cfg).to(device)
    # Vertex <-> index mappings drive the negative-sampling batches.
    v2ind, ind2v = model.get_mapping()
    sampler = Sampler(graph, v2ind, batch_size=cfg.SAMPLE.BATCHSIZE)

    criterion = KLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.WORD2VEC.LR)

    train(sampler, model, cfg, criterion, optimizer, device)
    embedding = model.get_embedding()

    eval_embedding(embedding, cfg.DATA.LABEL_PATH, logger)
    vis_embedding(embedding, cfg.DATA.LABEL_PATH, log_path)
Exemplo n.º 13
0
def main():
    """Train the FOL trajectory-prediction model end to end.

    Parses CLI/config options, builds the model, optimizer, logger,
    dataloaders and LR/hyperparameter schedulers, then runs the
    train / val / inference loop for cfg.SOLVER.MAX_EPOCH epochs,
    saving a checkpoint after every epoch.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument(
        "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # NOTE(review): unlike the sibling entry points, cfg is never frozen
    # here — confirm whether later mutation is intended.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # build model, optimizer and scheduler
    model = make_model(cfg)
    model = model.to(cfg.DEVICE)
    optimizer = build_optimizer(cfg, model)
    print('optimizer built!')
    # NOTE: add separate optimizers to train single object predictor and interaction predictor

    # W&B-backed logger when enabled; otherwise a bare stdlib Logger.
    if cfg.USE_WANDB:
        logger = Logger("FOL", cfg, project=cfg.PROJECT, viz_backend="wandb")
    else:
        logger = logging.Logger("FOL")

    # NOTE(review): built but never used in this function — presumably the
    # dataloaders read these settings from cfg instead; confirm.
    dataloader_params = {
        "batch_size": cfg.SOLVER.BATCH_SIZE,
        "shuffle": True,
        "num_workers": cfg.DATALOADER.NUM_WORKERS
    }

    # get dataloaders
    train_dataloader = make_dataloader(cfg, 'train')
    val_dataloader = make_dataloader(cfg, 'val')
    test_dataloader = make_dataloader(cfg, 'test')
    print('Dataloader built!')
    # get train_val_test engines
    do_train, do_val, inference = build_engine(cfg)
    print('Training engine built!')
    # W&B runs expose a run_id; fall back to a fixed directory otherwise.
    if hasattr(logger, 'run_id'):
        run_id = logger.run_id
    else:
        run_id = 'no_wandb'

    save_checkpoint_dir = os.path.join(cfg.CKPT_DIR, run_id)
    if not os.path.exists(save_checkpoint_dir):
        os.makedirs(save_checkpoint_dir)

    # NOTE: hyperparameter scheduler
    # Sigmoid-annealed KL weight: 0 -> 100.0, centered at step 400.
    model.param_scheduler = ParamScheduler()
    model.param_scheduler.create_new_scheduler(name='kld_weight',
                                               annealer=sigmoid_anneal,
                                               annealer_kws={
                                                   'device': cfg.DEVICE,
                                                   'start': 0,
                                                   'finish': 100.0,
                                                   'center_step': 400.0,
                                                   'steps_lo_to_hi': 100.0,
                                               })

    # Sigmoid-annealed latent-logit clip: 0.05 -> 5.0, centered at step 300.
    model.param_scheduler.create_new_scheduler(name='z_logit_clip',
                                               annealer=sigmoid_anneal,
                                               annealer_kws={
                                                   'device': cfg.DEVICE,
                                                   'start': 0.05,
                                                   'finish': 5.0,
                                                   'center_step': 300.0,
                                                   'steps_lo_to_hi': 300.0 / 5.
                                               })

    if cfg.SOLVER.scheduler == 'exp':
        # exponential schedule
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                        gamma=cfg.SOLVER.GAMMA)
    elif cfg.SOLVER.scheduler == 'plateau':
        # Plateau scheduler
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                            factor=0.2,
                                                            patience=5,
                                                            min_lr=1e-07,
                                                            verbose=1)
    else:
        # Fallback: fixed milestones at epochs 25 and 40.
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                      milestones=[25, 40],
                                                      gamma=0.2)

    print('Schedulers built!')

    for epoch in range(cfg.SOLVER.MAX_EPOCH):
        logger.info("Epoch:{}".format(epoch))
        do_train(cfg,
                 epoch,
                 model,
                 optimizer,
                 train_dataloader,
                 cfg.DEVICE,
                 logger=logger,
                 lr_scheduler=lr_scheduler)
        val_loss = do_val(cfg,
                          epoch,
                          model,
                          val_dataloader,
                          cfg.DEVICE,
                          logger=logger)
        # Runs every epoch (the modulus-1 guard is always true).
        if (epoch + 1) % 1 == 0:
            inference(cfg,
                      epoch,
                      model,
                      test_dataloader,
                      cfg.DEVICE,
                      logger=logger,
                      eval_kde_nll=False)

        torch.save(
            model.state_dict(),
            os.path.join(save_checkpoint_dir,
                         'Epoch_{}.pth'.format(str(epoch).zfill(3))))

        # update LR
        # NOTE(review): MultiStepLR.step also receives val_loss here; only
        # ReduceLROnPlateau expects a metric — confirm against the torch
        # version in use.
        if cfg.SOLVER.scheduler != 'exp':
            lr_scheduler.step(val_loss)
Exemplo n.º 14
0
def train(cfg_path, device='cuda'):
    """Train a UNet segmentation model according to the yacs config.

    Args:
        cfg_path: optional path to a YAML config merged into the global
            ``cfg`` before it is frozen.
        device: torch device string the model and tensors are moved to.

    Side effects: writes TensorBoard logs to ``cfg.LOG_DIR``, a rolling
    checkpoint and the best-validation weights to ``cfg.SAVE_DIR``.
    """
    if cfg_path is not None:
        cfg.merge_from_file(cfg_path)
    cfg.freeze()

    if not os.path.isdir(cfg.LOG_DIR):
        os.makedirs(cfg.LOG_DIR)
    if not os.path.isdir(cfg.SAVE_DIR):
        os.makedirs(cfg.SAVE_DIR)

    model = UNet(cfg.NUM_CHANNELS, cfg.NUM_CLASSES)
    model.to(device)

    train_data_loader = build_data_loader(cfg, 'train')
    if cfg.VAL:
        val_data_loader = build_data_loader(cfg, 'val')
    else:
        val_data_loader = None

    optimizer = build_optimizer(cfg, model)
    lr_scheduler = build_lr_scheduler(cfg, optimizer)
    criterion = get_loss_func(cfg)
    writer = SummaryWriter(cfg.LOG_DIR)

    iter_counter = 0
    loss_meter = AverageMeter()
    val_loss_meter = AverageMeter()
    min_val_loss = 1e10

    print('Training Start')
    for epoch in range(cfg.SOLVER.MAX_EPOCH):
        print('Epoch {}/{}'.format(epoch + 1, cfg.SOLVER.MAX_EPOCH))
        if lr_scheduler is not None:
            lr_scheduler.step(epoch)
        for data in train_data_loader:
            iter_counter += 1

            imgs, annots = data
            imgs = imgs.to(device)
            annots = annots.to(device)

            y = model(imgs)
            optimizer.zero_grad()
            loss = criterion(y, annots)
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())

            # Log the running training loss every 10 iterations.
            if iter_counter % 10 == 0:
                writer.add_scalars('loss', {'train': loss_meter.avg},
                                   iter_counter)
                loss_meter.reset()
            if lr_scheduler is not None:
                writer.add_scalar('learning rate',
                                  optimizer.param_groups[0]['lr'],
                                  iter_counter)
            save_as_checkpoint(model, optimizer,
                               os.path.join(cfg.SAVE_DIR, 'checkpoint.pth'),
                               epoch, iter_counter)

        # Skip validation when cfg.VAL is False
        if val_data_loader is None:
            continue

        # Bug fix: the meter used to be reset inside the batch loop, so
        # its average only reflected the final validation batch. Reset
        # once per epoch instead.
        val_loss_meter.reset()
        for data in val_data_loader:
            with torch.no_grad():
                imgs, annots = data
                imgs = imgs.to(device)
                annots = annots.to(device)

                # Forward-only pass; no backward, so the original
                # optimizer.zero_grad() here was unnecessary and removed.
                y = model(imgs)
                loss = criterion(y, annots)
                val_loss_meter.update(loss.item())
        if val_loss_meter.avg < min_val_loss:
            min_val_loss = val_loss_meter.avg
            writer.add_scalars('loss', {'val': val_loss_meter.avg},
                               iter_counter)
            # save model if validation loss is minimum
            torch.save(model.state_dict(),
                       os.path.join(cfg.SAVE_DIR, 'min_val_loss.pth'))
Exemplo n.º 15
0
def main():
    """Train the two-task NDDR network (task 1 + task 2 heads).

    Builds the train (and optionally eval) dataloaders, two DeepLab
    backbones joined by NDDR layers, an SGD optimizer with per-group
    learning rates, then runs a step-based training loop with periodic
    TensorBoard logging, checkpointing and optional checkpoint eval.
    """
    parser = argparse.ArgumentParser(description="PyTorch NDDR Training")
    parser.add_argument(
        "--config-file",
        default="configs/vgg16_nddr_pret.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    # Experiment name = config file basename without its ".yaml" suffix.
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if not os.path.exists(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME)):
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    train_loader = torch.utils.data.DataLoader(
        MultiTaskDataset(
            data_dir=cfg.DATA_DIR,
            data_list_1=cfg.TRAIN.DATA_LIST_1,
            data_list_2=cfg.TRAIN.DATA_LIST_2,
            output_size=cfg.TRAIN.OUTPUT_SIZE,
            random_scale=cfg.TRAIN.RANDOM_SCALE,
            random_mirror=cfg.TRAIN.RANDOM_MIRROR,
            random_crop=cfg.TRAIN.RANDOM_CROP,
            ignore_label=cfg.IGNORE_LABEL,
        ),
        batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True)

    # load the data (eval split, only when checkpoints should be evaluated)
    if cfg.TRAIN.EVAL_CKPT:
        test_loader = torch.utils.data.DataLoader(
            MultiTaskDataset(
                data_dir=cfg.DATA_DIR,
                data_list_1=cfg.TEST.DATA_LIST_1,
                data_list_2=cfg.TEST.DATA_LIST_2,
                output_size=cfg.TEST.OUTPUT_SIZE,
                random_scale=cfg.TEST.RANDOM_SCALE,
                random_mirror=cfg.TEST.RANDOM_MIRROR,
                random_crop=cfg.TEST.RANDOM_CROP,
                ignore_label=cfg.IGNORE_LABEL,
            ),
            batch_size=cfg.TEST.BATCH_SIZE, shuffle=False)

    # One timestamped TensorBoard log directory per run.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d~%H:%M:%S")
    experiment_log_dir = os.path.join(cfg.LOG_DIR, cfg.EXPERIMENT_NAME, timestamp)
    if not os.path.exists(experiment_log_dir):
        os.makedirs(experiment_log_dir)
    writer = SummaryWriter(logdir=experiment_log_dir)

    # Two single-task backbones initialized from pretrained weights,
    # joined into one multi-task network by NDDR layers.
    net1 = DeepLabLargeFOVBN(3, cfg.MODEL.NET1_CLASSES, weights=cfg.TRAIN.WEIGHT_1)
    net2 = DeepLabLargeFOVBN(3, cfg.MODEL.NET2_CLASSES, weights=cfg.TRAIN.WEIGHT_2)
    model = NDDRNet(net1, net2,
                    shortcut=cfg.MODEL.SHORTCUT,
                    bn_before_relu=cfg.MODEL.BN_BEFORE_RELU)

    if cfg.CUDA:
        model = model.cuda()
    model.train()
    steps = 0

    # 255 marks unlabeled pixels in the segmentation ground truth.
    seg_loss = nn.CrossEntropyLoss(ignore_index=255)

    # hacky way to pick params: partition parameters by name substring so
    # each group can get its own learning-rate multiplier.
    nddr_params = []
    fc8_weights = []
    fc8_bias = []
    base_params = []
    for k, v in model.named_parameters():
        if 'nddrs' in k:
            nddr_params.append(v)
        elif cfg.MODEL.FC8_ID in k:
            if 'weight' in k:
                fc8_weights.append(v)
            else:
                assert 'bias' in k
                fc8_bias.append(v)
        else:
            base_params.append(v)
    assert len(nddr_params) > 0 and len(fc8_weights) > 0 and len(fc8_bias) > 0

    # fc8 and NDDR groups train with scaled learning rates vs. the base.
    parameter_dict = [
        {'params': base_params},
        {'params': fc8_weights, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_WEIGHT_FACTOR},
        {'params': fc8_bias, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_BIAS_FACTOR},
        {'params': nddr_params, 'lr': cfg.TRAIN.LR * cfg.TRAIN.NDDR_FACTOR}
    ]
    optimizer = optim.SGD(parameter_dict, lr=cfg.TRAIN.LR, momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # Step-indexed LR decay: polynomial or cosine over cfg.TRAIN.STEPS.
    if cfg.TRAIN.SCHEDULE == 'Poly':
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER, last_epoch=-1)
    elif cfg.TRAIN.SCHEDULE == 'Cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.TRAIN.STEPS)
    else:
        raise NotImplementedError

    # Step-based loop: keep cycling the loader until cfg.TRAIN.STEPS
    # optimizer steps have been taken.
    while steps < cfg.TRAIN.STEPS:
        for batch_idx, (image, label_1, label_2) in enumerate(train_loader):
            if cfg.CUDA:
                image, label_1, label_2 = image.cuda(), label_1.cuda(), label_2.cuda()
            optimizer.zero_grad()
            out1, out2 = model(image)

            # loss_seg = get_seg_loss(out1, label_1, 40, 255)
            loss_seg = seg_loss(out1, label_1.squeeze(1))
            loss_normal = get_normal_loss(out2, label_2, 255)

            # Weighted sum of the two task losses.
            loss = loss_seg + cfg.TRAIN.NORMAL_FACTOR * loss_normal

            loss.backward()

            optimizer.step()
            scheduler.step()

            # Print out the loss periodically.
            if steps % cfg.TRAIN.LOG_INTERVAL == 0:
                print('Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLoss1: {:.6f}\tLoss2: {:.6f}'.format(
                    steps, batch_idx * len(image), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data.item(),
                    loss_seg.data.item(), loss_normal.data.item()))

                # Log to tensorboard
                writer.add_scalar('lr', scheduler.get_lr()[0], steps)
                writer.add_scalar('loss/overall', loss.data.item(), steps)
                writer.add_scalar('loss/seg', loss_seg.data.item(), steps)
                writer.add_scalar('loss/normal', loss_normal.data.item(), steps)

                # Log the first image of the batch plus predictions vs. GT.
                writer.add_image('image', process_image(image[0]), steps)
                seg_pred, seg_gt = process_seg_label(
                    out1.argmax(dim=1)[0].detach(),
                    label_1.squeeze(1)[0],
                    cfg.MODEL.NET1_CLASSES
                )
                writer.add_image('seg/pred', seg_pred, steps)
                writer.add_image('seg/gt', seg_gt, steps)
                normal_pred, normal_gt = process_normal_label(out2[0].detach(), label_2[0], 255)
                writer.add_image('normal/pred', normal_pred, steps)
                writer.add_image('normal/gt', normal_gt, steps)

            if steps % cfg.TRAIN.SAVE_INTERVAL == 0:
                # NOTE(review): the loss tensors are stored without
                # .item()/.detach(), which keeps autograd references alive
                # in the saved checkpoint — confirm intended.
                checkpoint = {
                    'cfg': cfg,
                    'step': steps,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss,
                    'loss_seg': loss_seg,
                    'loss_normal': loss_normal,
                    'mIoU': None,
                    'PixelAcc': None,
                    'angle_metrics': None,
                }

                # Optionally evaluate the checkpoint before saving, and
                # record the metrics inside it.
                if cfg.TRAIN.EVAL_CKPT:
                    model.eval()
                    mIoU, pixel_acc, angle_metrics = evaluate(test_loader, model)
                    writer.add_scalar('eval/mIoU', mIoU, steps)
                    writer.add_scalar('eval/PixelAcc', pixel_acc, steps)
                    for k, v in angle_metrics.items():
                        writer.add_scalar('eval/{}'.format(k), v, steps)
                    checkpoint['mIoU'] = mIoU
                    checkpoint['PixelAcc'] = pixel_acc
                    checkpoint['angle_metrics'] = angle_metrics
                    model.train()

                torch.save(checkpoint, os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                                                    'ckpt-%s.pth' % str(steps).zfill(5)))

            steps += 1
            if steps >= cfg.TRAIN.STEPS:
                break
Exemplo n.º 16
0
def main():
    """Train an image classifier, optionally with cross-neuron communication.

    Parses command-line options, merges them into the global ``cfg``, builds the
    model and data loaders, optionally resumes from a checkpoint, then runs the
    epoch loop (train + test), checkpointing periodically.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch Image Classification')
    parser.add_argument('--dataset',
                        type=str,
                        default='cifar100',
                        help='specify training dataset')
    parser.add_argument('--session',
                        type=int,
                        # was the string '1'; argparse coerced it via type=int,
                        # but an int default states the intent directly
                        default=1,
                        help='training session to record multiple runs')
    parser.add_argument('--arch',
                        type=str,
                        default='resnet110',
                        help='specify network architecture')
    parser.add_argument('--bs',
                        dest="batch_size",
                        type=int,
                        default=128,
                        help='training batch size')
    parser.add_argument('--gpu0-bs',
                        dest="gpu0_bs",
                        type=int,
                        default=0,
                        help='training batch size on gpu0')
    parser.add_argument('--add-ccn',
                        type=str,
                        default='no',
                        help='add cross neuron communication')
    parser.add_argument('--mgpus',
                        type=str,
                        default="no",
                        help='multi-gpu training')
    parser.add_argument('--resume',
                        dest="resume",
                        type=int,
                        default=0,
                        help='resume epoch')

    args = parser.parse_args()
    cfg.merge_from_file(osp.join("configs", args.dataset + ".yaml"))
    cfg.dataset = args.dataset
    cfg.arch = args.arch
    # comparisons are already booleans; no ternary needed
    cfg.add_cross_neuron = args.add_ccn == "yes"
    use_cuda = torch.cuda.is_available()
    cfg.use_cuda = use_cuda
    cfg.training.batch_size = args.batch_size
    cfg.mGPUs = args.mgpus == "yes"

    torch.manual_seed(cfg.initialize.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = create_data_loader(cfg)
    model = CrossNeuronNet(cfg)
    print("parameter number: %d" % (count_parameters(model)))

    # Report FLOPs/params for the chosen input resolution.
    # NOTE(review): torch.cuda.device(0) assumes a GPU is present even when
    # use_cuda is False — confirm before running on CPU-only hosts.
    with torch.cuda.device(0):
        if args.dataset == "cifar100":
            flops, params = get_model_complexity_info(
                model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
        elif args.dataset == "imagenet":
            flops, params = get_model_complexity_info(
                model, (3, 224, 224),
                as_strings=True,
                print_per_layer_stat=True)
        print('Flops: {}'.format(flops))
        print('Params: {}'.format(params))

    model = model.to(device)

    optimizer = optim.SGD(model.parameters(),
                          lr=cfg.optimizer.lr,
                          momentum=cfg.optimizer.momentum,
                          weight_decay=cfg.optimizer.weight_decay)
    # Wrap for multi-GPU AFTER building the optimizer so parameter identity
    # matches between optimizer and model.
    if cfg.mGPUs:
        if args.gpu0_bs > 0:
            model = BalancedDataParallel(args.gpu0_bs, model).to(device)
        else:
            model = nn.DataParallel(model).to(device)

    lr = cfg.optimizer.lr
    # NOTE(review): "checkponts" looks like a typo for "checkpoints"; kept
    # byte-identical so existing checkpoint directories keep resolving.
    checkpoint_tag = osp.join("checkponts", args.dataset, args.arch)
    # race-free replacement for the exists()+makedirs() check
    os.makedirs(checkpoint_tag, exist_ok=True)

    if args.resume > 0:
        ckpt_path = osp.join(checkpoint_tag,
                             ("ccn" if cfg.add_cross_neuron else "plain") +
                             "_{}_{}.pth".format(args.session, args.resume))
        print("resume model from {}".format(ckpt_path))
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt["model"])
        print("resume model successfully")
        acc = test(cfg, model, device, test_loader)

    best_acc = 0
    for epoch in range(args.resume + 1, cfg.optimizer.max_epoch + 1):
        # Step-wise LR decay at the configured milestones.
        if epoch in cfg.optimizer.lr_decay_schedule:
            adjust_learning_rate(optimizer, cfg.optimizer.lr_decay_gamma)
            lr *= cfg.optimizer.lr_decay_gamma
        print('Train Epoch: {} learning rate: {}'.format(epoch, lr))
        tic = time.time()
        train(cfg, model, device, train_loader, optimizer, epoch)
        acc = test(cfg, model, device, test_loader)
        time_cost = time.time() - tic
        best_acc = max(best_acc, acc)
        print(
            '\nModel: {} Best Accuracy-Baseline: {}\tTime Cost per Epoch: {}\n'
            .format(
                checkpoint_tag + ("ccn" if args.add_ccn == "yes" else "plain"),
                best_acc, time_cost))

        if epoch % cfg.log.checkpoint_interval == 0:
            checkpoint = {
                "arch": cfg.arch,
                "model": model.state_dict(),
                "epoch": epoch,
                "lr": lr,
                "test_acc": acc,
                "best_acc": best_acc
            }
            torch.save(
                checkpoint,
                osp.join(checkpoint_tag,
                         ("ccn" if cfg.add_cross_neuron else "plain") +
                         "_{}_{}.pth".format(args.session, epoch)))
Exemplo n.º 17
0
from skimage import io
from skimage.transform import resize
from model.utils.simplesum_octconv import simplesum

parser = argparse.ArgumentParser(description='PyTorch SOD')

parser.add_argument(
    "--config",
    default="",
    metavar="FILE",
    help="path to config file",
    type=str,
)
args = parser.parse_args()
# `assert` is stripped under `python -O`; fail loudly with a proper usage
# error instead when the config file is missing.
if not os.path.isfile(args.config):
    parser.error("config file not found: {!r}".format(args.config))
cfg.merge_from_file(args.config)

# Pin CUDA device enumeration to PCI order so cfg.GPU indexes are stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.GPU)
if cfg.TASK == '':
    # Default the task name to the architecture when not configured.
    cfg.TASK = cfg.MODEL.ARCH

print(cfg)


def main():
    """Import the configured architecture module, build its model, move it to GPU."""
    global cfg
    arch_module = importlib.import_module("model." + cfg.MODEL.ARCH)
    model = arch_module.build_model(predefine=cfg.TEST.MODEL_CONFIG)
    model.cuda()
Exemplo n.º 18
0
    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             criterion, checkpointer, device, checkpoint_period, arguments,
             logger)


def parse_args():
    """Parse command-line arguments for the retrieval-network trainer."""
    arg_parser = argparse.ArgumentParser(description='Train a retrieval network')
    arg_parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='config file',
        default='/home/songkun/PycharmProjects/TPAMI/configs/example.yaml',
        type=str)
    return arg_parser.parse_args()


if __name__ == '__main__':
    # Script entry point: merge the YAML file into the global cfg, then train.
    args = parse_args()
    cfg.merge_from_file(args.cfg_file)
    train(cfg)
Exemplo n.º 19
0
def main():
    """Multi-GPU (ChainerMN) training entry point for a detection model.

    One MPI rank drives one GPU; rank 0 additionally owns logging, progress
    reporting, and snapshotting.
    """
    args = parse_args()
    cfg.merge_from_file(args.config)
    cfg.freeze()

    # Start (and immediately join) a throwaway process so the 'forkserver'
    # start method is initialised before any CUDA/NCCL state exists in this
    # process — forking after CUDA init is unsafe.
    if hasattr(multiprocessing, 'set_start_method'):
        multiprocessing.set_start_method('forkserver')
        p = multiprocessing.Process()
        p.start()
        p.join()

    comm = chainermn.create_communicator('pure_nccl')
    # The launch configuration must supply exactly one rank per configured GPU.
    assert comm.size == cfg.n_gpu
    device = comm.intra_rank  # local GPU index on this node

    if comm.rank == 0:
        print(cfg)

    model = setup_model(cfg)
    train_chain = setup_train_chain(cfg, model)
    chainer.cuda.get_device_from_id(device).use()
    train_chain.to_gpu()

    train_dataset = TransformDataset(setup_dataset(cfg, 'train'),
                                     ('img', 'bbox', 'label'), Transform())
    # Rank 0 builds the index list; scatter_dataset shuffles and shards it
    # across all ranks so each GPU sees a disjoint subset per epoch.
    if comm.rank == 0:
        indices = np.arange(len(train_dataset))
    else:
        indices = None
    indices = chainermn.scatter_dataset(indices, comm, shuffle=True)
    train_dataset = train_dataset.slice[indices]
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset,
        cfg.n_sample_per_gpu,
        n_processes=cfg.n_worker,
        shared_mem=100 * 1000 * 1000 * 4)  # ~400 MB shared-memory buffer
    optimizer = chainermn.create_multi_node_optimizer(setup_optimizer(cfg),
                                                      comm)
    optimizer = optimizer.setup(train_chain)
    optimizer = add_hook_optimizer(optimizer, cfg)
    freeze_params(cfg, train_chain.model)

    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=device,
                                                converter=converter)
    trainer = training.Trainer(updater, (cfg.solver.n_iteration, 'iteration'),
                               get_outdir(args.config))

    # Extensions: reporting and snapshots are attached on rank 0 only.
    if comm.rank == 0:
        log_interval = 10, 'iteration'
        trainer.extend(training.extensions.LogReport(trigger=log_interval))
        trainer.extend(training.extensions.observe_lr(), trigger=log_interval)
        trainer.extend(training.extensions.PrintReport([
            'epoch',
            'iteration',
            'lr',
            'main/loss',
            'main/loss/loc',
            'main/loss/conf',
        ]),
                       trigger=log_interval)
        trainer.extend(training.extensions.ProgressBar(update_interval=10))

        trainer.extend(training.extensions.snapshot(),
                       trigger=(10000, 'iteration'))
        trainer.extend(training.extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}'),
                       trigger=(cfg.solver.n_iteration, 'iteration'))
        if args.tensorboard:
            trainer.extend(
                LogTensorboard(
                    ['lr', 'main/loss', 'main/loss/loc', 'main/loss/conf'],
                    trigger=(10, 'iteration'),
                    log_dir=get_logdir(args.config)))

    # Step-wise LR decay (x0.1 at each configured iteration milestone).
    if len(cfg.solver.lr_step):
        trainer.extend(
            training.extensions.MultistepShift('lr', 0.1, cfg.solver.lr_step,
                                               cfg.solver.base_lr, optimizer))

    # strict=False tolerates snapshots that miss newly added entries.
    if args.resume:
        serializers.load_npz(args.resume, trainer, strict=False)

    trainer.run()
Exemplo n.º 20
0
def main():
    """Entry point for few-shot transfer learning on FashionProductImages.

    Parses CLI options into the global ``cfg``, sets up logging and output
    directories, builds the configured architecture, then either evaluates a
    pretrained checkpoint (--test_only) or meta-trains and evaluates.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Transfer Learning Task")
    parser.add_argument("--config",
                        default="",
                        metavar="FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument("--load_ckpt",
                        default=None,
                        metavar="FILE",
                        help="path to the pretrained model to load ",
                        type=str)
    parser.add_argument(
        '--test_only',
        action='store_true',
        help="test the model (need to provide model path to --load_ckpt)")
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # The original `assert os.path.exists(args.load_ckpt)` raised a TypeError
    # when --load_ckpt was omitted (None) and vanished under `python -O`;
    # validate explicitly instead.
    if args.test_only and (args.load_ckpt is None
                           or not os.path.exists(args.load_ckpt)):
        parser.error(
            "You need to provide the valid model path using --load_ckpt")

    cfg.merge_from_file(args.config)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # Fixed seeds for reproducibility across numpy and torch.
    np.random.seed(1234)
    torch.manual_seed(1234)
    if 'cuda' in device:
        torch.cuda.manual_seed_all(1234)

    if not os.path.exists(cfg.OUTPUT_DIR):
        print('creating output dir: {}'.format(cfg.OUTPUT_DIR))
    # exist_ok avoids the check-then-create race of the original code
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    logger = setup_logger("fewshot", cfg.OUTPUT_DIR)
    logger.info("Loaded configuration file {}".format(args.config))
    with open(args.config, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
    logger.info("Saving config into: {}".format(output_config_path))
    # save overloaded model config in the output directory
    save_config(cfg, output_config_path)

    model = architectures[cfg.MODEL.ARCHITECTURE](cfg, logger)
    model.to(device)

    # Meta-test loader: one episode per batch, fixed order for comparability.
    meta_test_dataset = FashionProductImagesFewShot(cfg=cfg,
                                                    split='test',
                                                    k_way=cfg.TEST.K_WAY,
                                                    n_shot=cfg.TEST.N_SHOT,
                                                    logger=logger)
    meta_test_dataset.params['shuffle'] = False
    meta_test_dataset.params['batch_size'] = 1
    meta_test_dataloader = DataLoader(meta_test_dataset,
                                      **meta_test_dataset.params)

    if args.test_only:
        do_test(cfg, model, meta_test_dataloader, logger, 'test',
                args.load_ckpt)
    else:
        meta_train_dataset = FashionProductImagesFewShot(
            cfg=cfg,
            split='train',
            k_way=cfg.TRAIN.K_WAY,
            n_shot=cfg.TRAIN.N_SHOT,
            logger=logger)
        meta_train_dataloader = DataLoader(meta_train_dataset,
                                           **meta_train_dataset.params)

        do_train(cfg, model, meta_train_dataloader, logger, args.load_ckpt)
        do_test(cfg, model, meta_test_dataloader, logger, 'test', None)
    '''for inner_iter, data in enumerate(meta_test_dataloader):
Exemplo n.º 21
0
def main():
    """Train a metric-learning model with a batch triplet loss.

    Loads the config, builds the (DataParallel) model, optionally initialises
    the backbone from cfg.MODEL.WEIGHT, then runs the training loop and saves
    a state-dict per epoch into cfg.OUTPUT_DIR.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    # Kept despite being unused below — Visualizer may create the output
    # directory / log sinks as a side effect.
    viewer = Visualizer(cfg.OUTPUT_DIR)

    # Model
    model = build_model(cfg)
    model = DataParallel(model).cuda()
    if cfg.MODEL.WEIGHT != "":
        model.module.backbone.load_state_dict(torch.load(cfg.MODEL.WEIGHT))

    batch_time = AverageMeter()
    data_time = AverageMeter()

    # Optimizer and multi-step LR schedule, both driven by the config.
    optimizer = getattr(torch.optim, cfg.SOLVER.OPTIM)(
        model.parameters(),
        lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    lr_sche = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.SOLVER.STEPS, gamma=cfg.SOLVER.GAMMA)

    # Dataset
    datasets = make_dataset(cfg)
    dataloaders = make_dataloaders(cfg, datasets, True)
    # number of passes needed to reach MAX_ITER iterations
    iter_epoch = (cfg.SOLVER.MAX_ITER) // len(dataloaders[0]) + 1
    # makedirs(exist_ok) replaces the race-prone exists()+mkdir pair and also
    # creates missing parent directories.
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    ite = 0

    model.train()
    start = time.time()
    for epoch in tqdm.tqdm(range(iter_epoch), desc="epoch"):
        for dataloader in dataloaders:
            for imgs, labels, types in tqdm.tqdm(dataloader, desc="dataloader:"):
                data_time.update(time.time() - start)

                # anchor / positive / negative triplets stacked along the batch dim
                inputs = torch.cat([imgs[0].cuda(), imgs[1].cuda(), imgs[2].cuda()], dim=0)
                features = model(inputs)
                acc, loss = loss_opts.batch_triple_loss(features, labels, types, size_average=True)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Scheduler must step AFTER the optimizer (PyTorch >= 1.1);
                # the original stepped it first, skipping the initial LR.
                lr_sche.step()
                ite += 1
                print(acc, loss)
                batch_time.update(time.time() - start)
                start = time.time()

                print('Epoch: [{0}][{1}/{2}]\n'
                      'Time: {data_time.avg:.4f} ({batch_time.avg:.4f})\n'.format(
                    epoch, ite, len(dataloader),
                    data_time=data_time, batch_time=batch_time),
                    flush=True)

        torch.save(model.state_dict(), os.path.join(cfg.OUTPUT_DIR, "{}_{}.pth".format(cfg.MODEL.META_ARCHITECTURE, epoch)))
Exemplo n.º 22
0
def main():
    """Evaluate a list of saved triplet-model checkpoints.

    For each checkpoint, runs the evaluation dataloaders through the frozen
    model and prints overall and per-triplet-type accuracy (fraction of
    triplets with zero loss).
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    # Kept for its side effects (output-dir/log setup) even though unused below.
    viewer = Visualizer(cfg.OUTPUT_DIR)

    # Model
    model = build_model(cfg)
    model = DataParallel(model).cuda()
    if cfg.MODEL.WEIGHT != "":
        model.module.backbone.load_state_dict(torch.load(cfg.MODEL.WEIGHT))
        # freeze backbone
        for key, val in model.module.backbone.named_parameters():
            val.requires_grad = False

    data_time = AverageMeter()

    # Dataset (evaluation split; no shuffling)
    datasets = make_dataset(cfg, is_train=False)
    dataloaders = make_dataloaders(cfg, datasets, False)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)  # race-free dir creation

    # NOTE: removed the unused optimizer/LR-scheduler and the per-batch
    # lr_sche.step() call — they had no effect on this inference-only path.
    start = time.time()
    inference_list = ['resnet18_14.pth', 'resnet18_13.pth', 'resnet18_12.pth', 'resnet18_11.pth', 'resnet18_10.pth']
    for inference_weight in inference_list:
        # NOTE(review): `resume_dir` is not defined in this function —
        # presumably a module-level constant; confirm before running.
        model.load_state_dict(torch.load(os.path.join(resume_dir, inference_weight)))
        model.eval()

        total_count = 0
        one_count = 0
        two_count = 0
        three_count = 0
        one_number = 0
        two_number = 0
        three_number = 0
        for dataloader in dataloaders:
            for imgs, labels, types in tqdm.tqdm(dataloader, desc="dataloader:"):
                types = np.asarray(types)
                data_time.update(time.time() - start)

                inputs = torch.cat([imgs[0].cuda(), imgs[1].cuda(), imgs[2].cuda()], dim=0)
                with torch.no_grad():
                    features = model(inputs)
                acc, batch_loss = loss_opts.batch_triple_loss_acc(features, labels, types, size_average=True)
                print(batch_loss)
                # BUGFIX: removed the stray bare name `xxx` that raised
                # NameError on the first batch and aborted every evaluation.
                total_count += batch_loss.shape[0] - acc

                # Per-type counts: a triplet is "correct" when its loss is zero.
                ONE_CLASS = (batch_loss[np.nonzero(types == 'ONE_CLASS_TRIPLET')[0]])
                TWO_CLASS = (batch_loss[np.nonzero(types == 'TWO_CLASS_TRIPLET')[0]])
                THREE_CLASS = (batch_loss[np.nonzero(types == 'THREE_CLASS_TRIPLET')[0]])
                one_count += ONE_CLASS.shape[0] - torch.nonzero(ONE_CLASS).shape[0]
                two_count += TWO_CLASS.shape[0] - torch.nonzero(TWO_CLASS).shape[0]
                three_count += THREE_CLASS.shape[0] - torch.nonzero(THREE_CLASS).shape[0]
                one_number += ONE_CLASS.shape[0]
                two_number += TWO_CLASS.shape[0]
                three_number += THREE_CLASS.shape[0]
        print(inference_weight, total_count / (one_number + two_number + three_number), one_count / one_number, two_count / two_number, three_count / three_number)
Exemplo n.º 23
0
def main():
    """Train and/or evaluate a 2-D U-Net segmentation model from a YAML config."""
    # Read config settings from yaml file
    args = get_parser().parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # create output dir and logger
    output_dir = pathlib.Path(cfg.OUTPUT_DIR)
    output_dir.mkdir(exist_ok=True, parents=True)
    logger = setup_logger("U", "logs")
    logger.info("Loaded configuration file {}".format(cfg))

    lr = cfg.SOLVER.LEARN_RATE
    bs = cfg.SOLVER.BATCH_SIZE
    test_bs = cfg.SOLVER.TEST_BATCH_SIZE
    epochs = cfg.TRAINING.EPOCHS

    # sets the matplotlib display backend (most likely not needed)
    # mp.use('TkAgg', force=True)

    data = create_dataset()
    # NOTE(review): TTE/TEE presumably denote two distinct test modalities
    # (e.g. transthoracic vs transesophageal echo) — confirm with data owners.
    tte_test_dataset, tee_test_dataset = load_test_data()

    # Compute the train/validation split sizes from the configured fraction.
    print("Train Data length: {}".format(len(data)))
    train_partition = int(cfg.TRAINING.TRAIN_PARTITION * len(data))
    val_partition = len(data) - train_partition

    # split the training dataset and initialize the data loaders
    train_dataset, valid_dataset = torch.utils.data.random_split(
        data, (train_partition, val_partition))
    train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    valid_data = DataLoader(valid_dataset, batch_size=bs, shuffle=True)
    tte_test_data = DataLoader(tte_test_dataset,
                               batch_size=test_bs,
                               shuffle=True)
    tee_test_data = DataLoader(tee_test_dataset,
                               batch_size=test_bs,
                               shuffle=True)

    if cfg.TRAINING.VISUAL_DEBUG:
        plotter.plot_image_and_mask(data, 1)
    # Peek at one batch to sanity-check tensor shapes before training.
    xb, yb = next(iter(train_data))
    print(xb.shape, yb.shape)

    # build the Unet2D with default input and output channels as given by config
    unet = Unet2D()
    # logger.info(unet)

    # loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(unet.parameters(), lr=lr)

    # Test-only mode: evaluate and bail out without training.
    if cfg.TESTING.TEST_ONLY:
        run_test(unet, opt, tee_test_data)
        # NOTE(review): function name contains a typo ("evauate") — it is
        # defined elsewhere in the project, so it is left as-is here.
        evauate_and_log_results(logger, unet, tte_test_data, tee_test_data)
        return
    else:
        train_loss, valid_loss, valid_dice, valid_dice_per_class, pixel_acc = start_train(
            unet,
            train_data,
            valid_data,
            loss_fn,
            opt,
            dice_multiclass,
            acc_metric,
            epochs=epochs)

    # Evaluate the trained network on both test sets.
    evauate_and_log_results(logger, unet, tte_test_data, tee_test_data)

    if cfg.TRAINING.VISUAL_DEBUG:
        # plot training and validation losses
        plotter.plot_train_and_val_loss(train_loss, valid_loss)

        # plot average validation dice history
        plotter.plot_avg_dice_history(valid_dice)

        # plot per-class validation dice
        plotter.plot_dice_history(valid_dice_per_class)

        # show the predicted segmentations
        plotter.predict_on_batch_and_plot(tte_test_data, unet, test_bs)
        plotter.predict_on_batch_and_plot(tee_test_data, unet, test_bs)