コード例 #1
0
def main(args):
    """Assemble the data loaders, model, optimizer and scheduler, then train.

    Args:
        args: Parsed options. The attributes read here are ``gpu``,
            ``color``, ``vgg``, ``resnet``, ``old``, ``resume_epoch``,
            ``load_path``, ``lr``, ``momentum``, ``decay``, ``lr_step``,
            ``lr_gamma``, ``out_dir`` and ``epoch``; the whole object is also
            handed to ``get_dataloader`` and ``Trainer.loop``.
    """
    print('making dataloader...')
    (num_classes, train_loader, test_loader, raw_loader, part_classes,
     test_imgs) = get_dataloader(args)

    # Always ask for at least one device id, then cap the list to what
    # torch.cuda reports as available.
    gpu_count = 1
    if args.gpu > 0:
        gpu_count = args.gpu
    device_ids = list(range(torch.cuda.device_count()))[:gpu_count]

    if args.color:
        in_channels = 3
    else:
        in_channels = 1

    # Architecture switches are checked in priority order: vgg, resnet, old,
    # and finally the multi-head default.
    if args.vgg:
        model = vgg16_bn(num_classes=num_classes)
    elif args.resnet:
        model = resnet50(num_classes=num_classes)
    elif args.old:
        model = se_resnext50(num_classes=num_classes,
                             input_channels=in_channels)
    else:
        model = multi_serx50(class_list=part_classes,
                             input_channels=in_channels)

    # Defaults describe a fresh run; a non-zero resume_epoch overrides them
    # with the contents of the checkpoint at args.load_path.
    optimizer_state = None
    last_epoch = -1
    if args.resume_epoch != 0:
        with open(args.load_path, 'rb') as checkpoint_file:
            checkpoint = torch.load(checkpoint_file)
        model_state = checkpoint['weight']
        # The optimizer entry is optional in the checkpoint.
        if 'optimizer' in checkpoint:
            optimizer_state = checkpoint['optimizer']
        load_weight(model, model_state)
        last_epoch = args.resume_epoch

    parallel_model = nn.DataParallel(model, device_ids=device_ids)
    optimizer = optim.SGD(
        params=parallel_model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.decay,
    )
    if optimizer_state:
        load_weight(optimizer, optimizer_state)

    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        args.lr_step,
        gamma=args.lr_gamma,
        last_epoch=last_epoch,
    )

    print(f'training params: {args}')
    print('setting trainer...')
    trainer = Trainer(
        parallel_model,
        optimizer,
        save_dir=args.out_dir,
        test_imgs=test_imgs,
    )

    print('start loop')
    trainer.loop(
        args,
        args.epoch,
        train_loader,
        test_loader,
        raw_loader,
        scheduler,
        do_save=True,
    )
コード例 #2
0
ファイル: evaluate.py プロジェクト: MerHS/sketch-i2v
def get_network(args, model_path, class_len):
    """Build the evaluation network and load its weights from a checkpoint.

    Args:
        args: Parsed options; ``color``, ``vgg`` and ``gpu`` are read.
        model_path: Path of the checkpoint file. The loaded object must
            provide a ``'weight'`` entry.
        class_len: Number of output classes passed to the model constructor.

    Returns:
        The constructed network with the checkpoint weights applied, moved to
        CUDA when ``args.gpu > 0`` and switched to eval mode.
    """
    with open(model_path, 'rb') as f:
        # Deserialize onto the CPU so a checkpoint saved from CUDA tensors
        # still loads on a host without CUDA (the ``args.gpu <= 0`` path).
        # The network is moved to the GPU below when one is requested.
        network_weight = torch.load(f, map_location='cpu')['weight']

    in_channels = 3 if args.color else 1

    if args.vgg:
        network = vgg11_bn(num_classes=class_len, in_channels=in_channels)
    else:
        network = se_resnext50(num_classes=class_len,
                               input_channels=in_channels)

    load_weight(network, network_weight)

    if args.gpu > 0:
        network = network.cuda()

    network.eval()
    return network
コード例 #3
0
    p.add_argument("--blend", action='store_true')
    args = p.parse_args()

    # Fail fast when the image path given on the command line does not exist.
    if not Path(args.file_name).exists():
        raise Exception(f"{args.file_name} does not exists.")

    # Load the pickled tag metadata. tag_dict is inverted so that its former
    # values become the lookup keys.
    with open('taglist/tag_dump.pkl', 'rb') as f:
        pkl = pickle.load(f)
        iv_tag_list = pkl['iv_tag_list']
        tag_dict = pkl['tag_dict']
        tag_dict = {v: k for k, v in tag_dict.items()}

    # Only the 'weight' entry of the checkpoint at args.train_file is used.
    with open(args.train_file, 'rb') as f:
        network_weight = torch.load(f)['weight']

    # One output per entry of iv_tag_list, single input channel.
    network = se_resnext50(num_classes=len(iv_tag_list), input_channels=1)
    load_weight(network, network_weight)

    network.eval()

    # Read the image and pass it through make_square (size=512); the second
    # and third return values are discarded. The result is shown in the
    # "main" window.
    img = cv2.imread(args.file_name)
    img, _, _ = make_square(img, size=512, extend=(True, True))
    cv2.imshow("main", img)
    if args.sketch:
        # --sketch: convert the BGR image to a single grayscale channel.
        sketch_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    elif args.blend:
        # --blend: derive the sketch with get_sketch (blend=0.15) and show it
        # in the "sketch" window.
        from sketchify.sketchify import get_sketch
        sketch_img = get_sketch(img, blend=0.15)
        cv2.imshow("sketch", sketch_img)
    else:
        # Default path; the remainder of this branch is outside the visible
        # excerpt.
        from sketchify.sketchify import get_keras_high_intensity
コード例 #4
0
ファイル: visualize.py プロジェクト: MerHS/sketch-i2v
    # Command-line interface: checkpoint path, tag dump path, a switch for the
    # older single-head model, an integer part index and the positional image
    # path.
    p = argparse.ArgumentParser()
    p.add_argument("--load_path", default="result.pth")
    p.add_argument("--tag_dump", default=TAG_FILE_PATH)
    p.add_argument("--old", action="store_true")
    p.add_argument("--part", type=int, default=0)
    p.add_argument("img_path")
    args = p.parse_args()

    # class_info returns four values; the visible code below only uses
    # tag_list (for the class count) and tag_part_list (for the multi-head
    # model).
    tag_dump = class_info(args.tag_dump)
    tag_list, idx_dict, tag_part_list, tag_rev_dict = tag_dump
    class_len = len(tag_list)
    # The model is always built with a single input channel here.
    in_channels = 1

    if args.old:
        model = se_resnext50(num_classes=class_len, input_channels=in_channels)
    else:
        model = multi_serx50(class_list=tag_part_list,
                             input_channels=in_channels)

    # Restore the model weights from the checkpoint's 'weight' entry.
    with open(args.load_path, 'rb') as f:
        loaded = torch.load(f)
        network_weight = loaded['weight']
        # NOTE(review): opt_weight is not used in the visible part of this
        # script — confirm whether it is needed at all.
        if 'optimizer' in loaded:
            opt_weight = loaded['optimizer']
        else:
            opt_weight = None
        load_weight(model, network_weight)

    model.eval()
    # Presumably the name of the model's last convolutional stage; its use is
    # outside the visible excerpt — TODO confirm.
    finalconv_name = 'layer4'