Example #1
def train(args):
    logging.info('======= user config ======')
    logging.info(pprint(opt))
    logging.info(pprint(args))
    logging.info('======= end ======')

    train_data, valid_data = get_data_provider(opt)

    net = getattr(network, opt.network.name)(classes=opt.dataset.num_classes)
    optimizer = getattr(torch.optim,
                        opt.train.optimizer)(net.parameters(),
                                             lr=opt.train.lr,
                                             weight_decay=opt.train.wd,
                                             momentum=opt.train.momentum)
    ce_loss = nn.CrossEntropyLoss()
    lr_scheduler = LRScheduler(base_lr=opt.train.lr,
                               step=opt.train.step,
                               factor=opt.train.factor,
                               warmup_epoch=opt.train.warmup_epoch,
                               warmup_begin_lr=opt.train.warmup_begin_lr)
    net = nn.DataParallel(net)
    net = net.cuda()
    mod = Solver(opt, net)
    mod.fit(train_data=train_data,
            test_data=valid_data,
            optimizer=optimizer,
            criterion=ce_loss,
            lr_scheduler=lr_scheduler)
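LRScheduler above is the repository's own warmup-then-step scheduler. As a rough sketch of what such a schedule typically computes (an assumption, not the repo's actual implementation), the learning rate can be derived from the epoch index like this:

def step_lr(epoch, base_lr, step, factor, warmup_epoch=0, warmup_begin_lr=0.0):
    # Linear warmup from warmup_begin_lr to base_lr, then multiply by `factor`
    # at every epoch boundary listed in `step`.
    if warmup_epoch and epoch < warmup_epoch:
        return warmup_begin_lr + (base_lr - warmup_begin_lr) * epoch / warmup_epoch
    return base_lr * factor ** sum(1 for s in step if epoch >= s)

print(step_lr(5, 0.1, step=[40, 70], factor=0.1, warmup_epoch=10, warmup_begin_lr=0.01))
# 0.055 -- halfway through the warmup ramp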
Example #2
def train(args):
    logging.info('======= user config ======')
    logging.info(pprint(opt))
    logging.info(pprint(args))
    logging.info('======= end ======')

    train_data, test_data, num_query = get_data_provider(opt)

    net = getattr(network, opt.network.name)(opt.dataset.num_classes,
                                             opt.network.last_stride)
    net = nn.DataParallel(net).cuda()

    optimizer = getattr(torch.optim,
                        opt.train.optimizer)(net.parameters(),
                                             lr=opt.train.lr,
                                             weight_decay=opt.train.wd)
    ce_loss = nn.CrossEntropyLoss()
    triplet_loss = TripletLoss(margin=opt.train.margin)

    def ce_loss_func(scores, feat, labels):
        ce = ce_loss(scores, labels)
        return ce

    def tri_loss_func(scores, feat, labels):
        tri = triplet_loss(feat, labels)[0]
        return tri

    def ce_tri_loss_func(scores, feat, labels):
        ce = ce_loss(scores, labels)
        triplet = triplet_loss(feat, labels)[0]
        return ce + triplet

    if opt.train.loss_fn == 'softmax':
        loss_fn = ce_loss_func
    elif opt.train.loss_fn == 'triplet':
        loss_fn = tri_loss_func
    elif opt.train.loss_fn == 'softmax_triplet':
        loss_fn = ce_tri_loss_func
    else:
        raise ValueError('Unknown loss func {}'.format(opt.train.loss_fn))

    lr_scheduler = LRScheduler(base_lr=opt.train.lr,
                               step=opt.train.step,
                               factor=opt.train.factor,
                               warmup_epoch=opt.train.warmup_epoch,
                               warmup_begin_lr=opt.train.warmup_begin_lr)

    mod = Solver(opt, net)
    mod.fit(train_data=train_data,
            test_data=test_data,
            num_query=num_query,
            optimizer=optimizer,
            criterion=loss_fn,
            lr_scheduler=lr_scheduler)
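The if/elif above simply maps opt.train.loss_fn to one of three closures over (scores, feat, labels). A self-contained sketch of the same dispatch pattern (not from this repo; the batch-hard term below is only a stand-in for the repo's TripletLoss):

import torch
import torch.nn as nn
import torch.nn.functional as F

ce_loss = nn.CrossEntropyLoss()

def softmax_loss(scores, feat, labels):
    return ce_loss(scores, labels)

def triplet_loss(scores, feat, labels, margin=0.3):
    # Batch-hard stand-in: hardest positive distance vs. hardest negative distance.
    dist = torch.cdist(feat, feat)                       # pairwise L2 distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)    # True where labels match
    hardest_pos = (dist * same.float()).max(dim=1).values
    hardest_neg = dist.masked_fill(same, float('inf')).min(dim=1).values
    return F.relu(hardest_pos - hardest_neg + margin).mean()

def softmax_triplet_loss(scores, feat, labels):
    return softmax_loss(scores, feat, labels) + triplet_loss(scores, feat, labels)

LOSS_FNS = {
    'softmax': softmax_loss,
    'triplet': triplet_loss,
    'softmax_triplet': softmax_triplet_loss,
}

scores = torch.randn(8, 10)              # dummy classifier logits
feat = torch.randn(8, 128)               # dummy embeddings
labels = torch.randint(0, 10, (8,))      # dummy class ids
loss = LOSS_FNS['softmax_triplet'](scores, feat, labels)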
Example #3
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains

        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'test':
        solver.test()
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
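Munch, used above to bundle the three loaders, is a dict subclass whose keys double as attributes, so solver.train(loaders) can read loaders.src as well as loaders['src']. A brief illustration (not from the repo):

from munch import Munch

loaders = Munch(src='src_loader', ref='ref_loader', val='val_loader')
assert loaders.src == loaders['src'] == 'src_loader'    # attribute and key access agree
loaders.extra = 'another_loader'                        # attribute assignment adds a key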
Example #4
def test(args):
    logging.info('======= user config ======')
    logging.info(pprint(opt))
    logging.info(pprint(args))
    logging.info('======= end ======')

    train_data, test_data, num_query = get_data_provider(opt)

    net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride)
    net.load_state_dict(torch.load(args.load_model)['state_dict'])
    net = nn.DataParallel(net).cuda()

    mod = Solver(opt, net)
    mod.test_func(test_data, num_query)
Example #5
def main(_):
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

        if cfg.env_type == 'simple':
            env = SimpleGymEnvironment(cfg)
        else:
            env = GymEnvironment(cfg)

        if not os.path.exists('/tmp/model_dir'):
            os.mkdir('/tmp/model_dir')

        solver = Solver(cfg, env, sess, '/tmp/model_dir')

        solver.train()
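per_process_gpu_memory_fraction=0.4 pre-reserves 40% of GPU memory for the session. A hedged TF1-style variant (assuming TensorFlow 1.x, as above) that grows memory on demand and creates the model directory idempotently:

import os
import tensorflow as tf

gpu_options = tf.GPUOptions(allow_growth=True)          # allocate GPU memory lazily
os.makedirs('/tmp/model_dir', exist_ok=True)            # replaces the exists()/mkdir() pair
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    pass  # build the environment and Solver here, as in the example above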
Example #6
def main():

    solver = Solver()

    with open('imgNames.txt') as f:
        imgName = f.read()

    src = cv2.imread("tmp/uploads/" + imgName)
    ref = cv2.imread("src/stargan/assets/ref_" + str(randrange(7)) + ".jpg")

    res_img = solver.sample(src, ref)
    name_f = ''.join(
        random.choices(string.ascii_uppercase + string.digits, k=15)) + '.jpg'

    cv2.imwrite("tmp/" + name_f, res_img)

    with open('output.txt', 'w+') as f:
        f.write(name_f)
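cv2.imread returns None instead of raising when a path is wrong (for instance if the name read from imgNames.txt still carries a trailing newline), so a defensive sketch of the read step could look like this (file names are hypothetical):

import cv2

with open('imgNames.txt') as f:
    img_name = f.read().strip()          # drop a trailing newline, if any

src = cv2.imread("tmp/uploads/" + img_name)
if src is None:                          # cv2.imread signals failure with None
    raise FileNotFoundError("could not read tmp/uploads/" + img_name)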
Example #7
def main(config):
    # For fast training.
    cudnn.benchmark = True

    # Create directories if not exist.
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    if not os.path.exists(config.model_save_dir):
        os.makedirs(config.model_save_dir)
    if not os.path.exists(config.sample_dir):
        os.makedirs(config.sample_dir)
    if not os.path.exists(config.result_dir):
        os.makedirs(config.result_dir)

    # Data loader.
    celeba_loader = None
    rafd_loader = None

    if config.dataset in ['CelebA', 'Both']:
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                   config.celeba_crop_size, config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)
    if config.dataset in ['RaFD', 'Both']:
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size, config.batch_size,
                                 'RaFD', config.mode, config.num_workers)
    

    # Solver for training and testing StarGAN.
    solver = Solver(celeba_loader, rafd_loader, config)

    if config.mode == 'train':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.train()
        elif config.dataset in ['Both']:
            solver.train_multi()
    elif config.mode == 'test':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.test()
        elif config.dataset in ['Both']:
            solver.test_multi()
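The four exists()/makedirs() pairs above can be collapsed; a small sketch with a hypothetical helper name:

import os

def prepare_dirs(config):
    # exist_ok=True makes the call a no-op when the directory already exists
    for d in (config.log_dir, config.model_save_dir,
              config.sample_dir, config.result_dir):
        os.makedirs(d, exist_ok=True)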
Example #8
def train_net(cfg):
    net = ReconstructionNet(cfg)
    print('Network definition:')
    print(inspect.getsource(ReconstructionNet.network_definition))

    # Generate the solver
    solver = Solver(cfg, net)

    # Prefetching data processes
    #
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    global train_queue, val_queue, train_processes, val_processes
    train_queue = Queue(cfg.TRAIN.QUEUE_SIZE)
    val_queue = Queue(cfg.TRAIN.QUEUE_SIZE)

    train_processes = make_processes(cfg,
                                     train_queue,
                                     category_model_id_pair(
                                         cfg.DIR.DATASET_TAXONOMY_FILE_PATH,
                                         cfg.DIR.DATASET_QUERY_PATH,
                                         cfg.TRAIN.DATASET_PORTION),
                                     cfg.TRAIN.NUM_WORKER,
                                     repeat=True)
    val_processes = make_processes(cfg,
                                   val_queue,
                                   category_model_id_pair(
                                       cfg.DIR.DATASET_TAXONOMY_FILE_PATH,
                                       cfg.DIR.DATASET_QUERY_PATH,
                                       cfg.TEST.DATASET_PORTION),
                                   1,
                                   repeat=True,
                                   train=False)

    # Train the network
    solver.train(train_queue, val_queue)

    # Cleanup the processes and the queue.
    kill_processes(train_queue, train_processes)
    kill_processes(val_queue, val_processes)
Example #9
def train(args):
    train_data, valid_data, train_valid_data = get_data_provider(args.bs)

    net = network.ResNet18(num_classes=10)
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                weight_decay=args.wd,
                                momentum=args.momentum)
    ce_loss = nn.CrossEntropyLoss()

    net = nn.DataParallel(net)
    if args.use_gpu:
        net = net.cuda()
    mod = Solver(net, args.use_gpu)
    mod.fit(train_data=train_data,
            test_data=valid_data,
            optimizer=optimizer,
            criterion=ce_loss,
            num_epochs=args.epochs,
            print_interval=args.print_interval,
            eval_step=args.eval_step,
            save_step=args.save_step,
            save_dir=args.save_dir)
Example #10
def create_model(args, model_type):
    args.resume_iter = 100000

    if model_type == 'CelebA-HQ':
        args.num_domains = 2
        args.w_hpf = 1
        args.checkpoint_dir = 'expr/checkpoints/celeba_hq'

    else:
        args.num_domains = 3
        args.w_hpf = 0
        args.checkpoint_dir = 'expr/checkpoints/afhq'

    model = Solver(args)
    return model
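The two branches above differ only in three settings, so a table-driven variant is possible. The sketch below (hypothetical, not from the repo) mutates an argparse.Namespace the same way and leaves the Solver construction to the caller:

import argparse

MODEL_PRESETS = {
    'CelebA-HQ': dict(num_domains=2, w_hpf=1,
                      checkpoint_dir='expr/checkpoints/celeba_hq'),
    'AFHQ':      dict(num_domains=3, w_hpf=0,
                      checkpoint_dir='expr/checkpoints/afhq'),
}

def configure_args(args, model_type):
    args.resume_iter = 100000
    for key, value in MODEL_PRESETS.get(model_type, MODEL_PRESETS['AFHQ']).items():
        setattr(args, key, value)
    return args

args = configure_args(argparse.Namespace(), 'CelebA-HQ')
print(args.checkpoint_dir)   # expr/checkpoints/celeba_hq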
Example #11
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        sintel_path = "/home/tomstrident/datasets/"
        video_id = "temple_2"
        test_loader = getTestDatasetLoader(sintel_path, video_id)
        train_loader, eval_loader = get_loaderFC2(
            args.data_dir, args.style_dir, args.temp_dir, args.batch_size,
            args.num_workers, args.num_domains, args.mode)
        print("start training ...")
        print("args.num_domains:", args.num_domains)
        solver.train([train_loader, test_loader])
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        _, eval_loader = get_loaderFC2(args.data_dir, args.style_dir,
                                       args.temp_dir, args.batch_size,
                                       args.num_workers, args.num_domains,
                                       args.mode)
        print("len(eval_loader)", len(eval_loader))
        solver.evaluate(loader=eval_loader)
        #solver.eval_sintel()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
Example #12
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)
    if args.mode == 'train':
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers,
                                             dataset_dir=args.dataset_dir),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers,
                                             dataset_dir=args.dataset_dir),
                        val=get_val_loader(root=args.val_img_dir,
                                           img_size=args.img_size,
                                           batch_size=args.val_batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           dataset_dir=args.dataset_dir))
        solver.train(loaders)

    elif args.mode == 'sample':  ## added by hyun for the 'styling_ref' mode
        solver.sample()

        # added by hyun
        #parsing(respth='./results/label/src', dspth=os.path.join(args.src_dir, trg_domain))  # parsing src_image
        #parsing(respth='./results/label/others', dspth=os.path.join(args.result_dir, trg_domain))  # parsing fake_image
        #reconstruct()  # 'styling' mode

    elif args.mode == 'eval':
        fid_values, fid_mean = solver.evaluate()
        for key, value in fid_values.items():
            print(key, value)
    else:
        raise NotImplementedError
Example #13
def main(args):
    print(args)
    #wandb.init(project="stargan", entity="stacey", config=args, name=args.model_name)
    #cfg = wandb.config
    #cfg.update({"dataset" : "afhq", "type" : "train"})
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate(args)
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    elif args.mode == 'custom':
        # override some default arguments
        wandb.init(project="stargan", config=args, name=args.model_name)
        # src or ref may each be a dir or an image
        # make temporary folders for images
        if os.path.isfile(args.custom_src):
            src_dir = "tmp_src"
            full_src = src_dir + "/src"
            if os.path.exists(src_dir):
                shutil.rmtree(src_dir)
            os.makedirs(full_src)
            shutil.copy2(args.custom_src, full_src)
            src_images = src_dir
        else:
            src_images = args.custom_src
        if os.path.isfile(args.custom_ref):
            ref_dir = "tmp_ref"
            full_ref = ref_dir + "/ref"
            if os.path.exists(ref_dir):
                shutil.rmtree(ref_dir)
            os.makedirs(full_ref)
            shutil.copy2(args.custom_ref, full_ref)
            if args.extend_domain:
                # make some extra domains
                for d in [ref_dir + "/ref2", ref_dir + "/ref3"]:
                    os.makedirs(d)
                    shutil.copy2(args.custom_ref, d)
            ref_images = ref_dir
        else:
            ref_images = args.custom_ref
        loaders = Munch(src=get_test_loader(root=src_images,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=ref_images,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.custom(loaders)
    else:
        raise NotImplementedError
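The 'custom' branch above stages a single file into a temporary domain folder so the folder-based test loader can read it. A hedged refactor sketch of that staging logic (hypothetical helper, not in the repo):

import os
import shutil

def stage_single_image(image_path, tmp_root, domain_name):
    # If image_path is already a directory of domains, use it unchanged;
    # otherwise copy the single image into <tmp_root>/<domain_name>/.
    if not os.path.isfile(image_path):
        return image_path
    if os.path.exists(tmp_root):
        shutil.rmtree(tmp_root)
    domain_dir = os.path.join(tmp_root, domain_name)
    os.makedirs(domain_dir)
    shutil.copy2(image_path, domain_dir)
    return tmp_root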
Example #14
def main(args):
    print(args)
    cudnn.benchmark = True
    if args.mode == 'train':
        torch.manual_seed(args.seed)

    solver = Solver(args)

    transform = transforms.Compose([
        transforms.Resize([args.img_size, args.img_size]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        if args.resume_iter > 0:
            solver._load_checkpoint(args.resume_iter)
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'eval':
        solver.evaluate()

    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)

    elif args.mode == 'inter':  # interpolation
        save_dir = args.save_dir
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        solver._load_checkpoint(args.resume_iter)
        nets_ema = solver.nets_ema
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        image_name = os.path.basename(args.input)
        image = Variable(
            transform(Image.open(
                args.input).convert('RGB')).unsqueeze(0).to(device))
        masks = nets_ema.fan.get_heatmap(image) if args.w_hpf > 0 else None
        y1 = torch.tensor([args.y1]).long().cuda()
        y2 = torch.tensor([args.y2]).long().cuda()
        outputs = interpolations(nets_ema,
                                 args.latent_dim,
                                 image,
                                 masks,
                                 lerp_step=0.1,
                                 y1=y1,
                                 y2=y2,
                                 lerp_mode=args.lerp_mode)
        path = os.path.join(save_dir, image_name)
        vutils.save_image(outputs.data, path, padding=0)

    elif args.mode == 'test':
        save_dir = args.save_dir
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        solver._load_checkpoint(args.resume_iter)
        nets_ema = solver.nets_ema

        image_name = os.path.basename(args.input)
        image = Variable(
            transform(Image.open(
                args.input).convert('RGB')).unsqueeze(0)).cuda()
        masks = nets_ema.fan.get_heatmap(image) if args.w_hpf > 0 else None

        image_ref = None
        if args.test_mode == 'reference':
            image_ref = Variable(
                transform(Image.open(
                    args.input_ref).convert("RGB")).unsqueeze(0)).cuda()

        fake = test_single(nets_ema, image, masks, args.latent_dim, image_ref,
                           args.target_domain, args.single_mode)
        fake = torch.clamp(fake * 0.5 + 0.5, 0, 1)
        path = os.path.join(save_dir, image_name)
        vutils.save_image(fake.data, path, padding=0)

    elif args.mode == 'video':
        save_dir = args.save_dir
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        solver._load_checkpoint(args.resume_iter)
        nets_ema = solver.nets_ema

        image_name = os.path.basename(args.input)
        image = Variable(
            transform(Image.open(
                args.input).convert('RGB')).unsqueeze(0)).cuda()
        masks = nets_ema.fan.get_heatmap(image) if args.w_hpf > 0 else None

        y1 = torch.tensor([args.y1]).long().cuda()
        y2 = torch.tensor([args.y2]).long().cuda()
        outputs = interpolations_loop(nets_ema,
                                      args.latent_dim,
                                      image,
                                      masks,
                                      lerp_step=0.02,
                                      y1=y1,
                                      y2=y2,
                                      lerp_mode=args.lerp_mode)
        outputs = torch.cat(outputs)
        outputs = tensor2ndarray255(outputs)
        path = os.path.join(save_dir, '{}-video.mp4'.format(image_name))
        save_video(path, outputs)

    else:
        raise NotImplementedError
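torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, so the image preprocessing in the 'inter'/'test'/'video' branches above can be written without it. A modern sketch (the file name and size are placeholders):

import torch
from PIL import Image
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([
    transforms.Resize([256, 256]),                       # stands in for args.img_size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
image = transform(Image.open('input.jpg').convert('RGB')).unsqueeze(0).to(device)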
Example #15
 def __init__(self):
     Solver.__init__(self, "topn", "fuzz")
Example #16
 def __init__(self):
     Solver.__init__(self, "pred", "ub")
Example #17
 def __init__(self):
     Solver.__init__(self, "topn", "ib")
     self._tids = TS.instance().ids()
Example #18
 def __init__(self):
     Solver.__init__(self, "pred", "fuzz")
Example #19
 def __init__(self, coordinator):
     self.output_dir = coordinator.checkpoints_dir()
     self.summary_dir = coordinator.summary_dir()
     self.solver = Solver()
Example #20
class TrainWrapper(object):
    def __init__(self, coordinator):
        self.output_dir = coordinator.checkpoints_dir()
        self.summary_dir = coordinator.summary_dir()
        self.solver = Solver()

    def init_weights_fn(self, init_weights):
        if os.path.exists(init_weights):
            logger.info('Loading weights from {}.'.format(init_weights))
            exclusions = ['learning_rate', 'global_step', 'OptimizeLoss',
                          'crnn/LSTMLayers/fully_connected']
            variables = slim.get_variables_to_restore(exclude=exclusions)
            init_fn = slim.assign_from_checkpoint_fn(init_weights, variables, ignore_missing_vars=True)
            return init_fn
        else:
            raise ValueError('Invalid path of weights: {}'.format(init_weights))

    def tower_loss(self, logits, labels, seqlen):
        ctc_loss = tf.nn.ctc_loss(labels=labels,
                                  inputs=logits,
                                  sequence_length=seqlen,
                                  ignore_longer_outputs_than_inputs=True)
        ctc_loss_mean = tf.reduce_mean(ctc_loss)
        # total_loss = tf.add_n([ctc_loss_mean] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        tf.add_to_collection('losses', ctc_loss_mean)
        losses = tf.get_collection('losses')
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n(losses + regularization_losses, name='total_loss')
        return ctc_loss_mean, total_loss

    def metrics(self, logits, labels, seqlen):
        softmax = tf.nn.softmax(logits, dim=-1, name='softmax')
        decoded, log_prob = tf.nn.ctc_greedy_decoder(softmax, seqlen)
        distance = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labels))
        return distance, decoded

    def train(self, data_loader):
        with tf.device('/cpu:0'):
            input_images, input_labels, input_widths = data_loader.read_with_bucket_queue(
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_threads=cfg.TRAIN.THREADS,
                num_epochs=cfg.TRAIN.EPOCH,
                shuffle=cfg.TRAIN.USE_SHUFFLE)

            images_sp = tf.split(input_images, cfg.NUM_GPUS)
            labels_sp = tf.sparse_split(sp_input=input_labels, num_split=cfg.NUM_GPUS, axis=0)
            widths_sp = tf.split(input_widths, cfg.NUM_GPUS)

            # if cfg.NUM_GPUS > 1:
            #     images_sp = tf.split(input_images, cfg.NUM_GPUS)
            #     labels_sp = tf.sparse_split(sp_input=input_labels, num_split=cfg.NUM_GPUS, axis=0)
            #     widths_sp = tf.split(input_widths, cfg.NUM_GPUS)
            # else:
            #     images_sp = [input_images]
            #     labels_sp = [input_labels]
            #     widths_sp = [input_widths]

            tower_grads = []
            tower_distance = []
            for i, host in enumerate(cfg.HOSTS):
                reuse = i > 0
                with tf.device('/gpu:%d' % host):
                    with tf.name_scope('model_%d' % host) as scope:
                        with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
                            logits = get_models(cfg.MODEL.BACKBONE)(cfg.MODEL.NUM_CLASSES).build(images_sp[i], True)

                        seqlen = tf.cast(tf.floor_div(widths_sp[i], 2), tf.int32, name='sequence_length')
                        model_loss, total_loss = self.tower_loss(logits, labels_sp[i], seqlen)
                        distance, _ = self.metrics(logits, labels_sp[i], seqlen)

                        if not reuse and cfg.ENABLE_TENSOR_BOARD:
                            tf.summary.image(name='InputImages', tensor=images_sp[i])
                            tf.summary.scalar(name='ModelLoss', tensor=model_loss)
                            tf.summary.scalar(name='TotalLoss', tensor=total_loss)
                            tf.summary.scalar(name='Distance', tensor=distance)

                        self.solver.bn_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                        grads = self.solver.optimizer.compute_gradients(total_loss)

                        tower_grads.append(grads)
                        tower_distance.append(distance)

            self.solver.apply_gradients(tower_grads)
            if cfg.SOLVER.USE_MOVING_AVERAGE_DECAY:
                self.solver.apply_moving_average()

            train_op = self.solver.get_train_op()
            distance_mean = tf.reduce_mean(tower_distance)

            if cfg.ENABLE_TENSOR_BOARD:
                summary_op = tf.summary.merge_all()
                summary_writer = tf.summary.FileWriter(self.summary_dir, tf.get_default_graph())

        saver = tf.train.Saver(tf.global_variables())
        gpu_options = tf.GPUOptions(allow_growth=True)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)) as sess:
            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            sess.run(init_op)
            if cfg.RESTORE:
                logger.info('continue training from previous checkpoint')
                ckpt = tf.train.latest_checkpoint(self.output_dir)
                logger.debug(ckpt)
                saver.restore(sess, ckpt)
            else:
                # Load the pre-trained weights
                if cfg.TRAIN.WEIGHTS:
                    self.init_weights_fn(cfg.TRAIN.WEIGHTS)(sess)

            start = time.time()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    _, step, lr, ml, tl, dt = sess.run(
                        [train_op, self.solver.global_step, self.solver.learning_rate, model_loss, total_loss,
                         distance_mean])

                    if np.isnan(tl):
                        logger.error('Loss diverged, stop training')
                        break

                    if step % cfg.SOLVER.DISPLAY == 0:
                        avg_time_per_step = (time.time() - start) / 10
                        avg_examples_per_second = (10 * cfg.TRAIN.BATCH_SIZE * cfg.NUM_GPUS) / (time.time() - start)
                        start = time.time()
                        tb = PrettyTable(
                            ['Step', 'LR', 'ModelLoss', 'TotalLoss', 'sec/step', 'exp/sec', 'Distance'])
                        tb.add_row(
                            ['{}/{}'.format(step, cfg.SOLVER.MAX_ITERS), '{:.3f}'.format(lr), '{:.3f}'.format(ml),
                             '{:.3f}'.format(tl),
                             '{:.3f}'.format(avg_time_per_step),
                             '{:.3f}'.format(avg_examples_per_second), '{:.3f}'.format(dt)])
                        print(tb)

                        if cfg.ENABLE_TENSOR_BOARD:
                            summary_str = sess.run([summary_op, ])
                            summary_writer.add_summary(summary_str[0], global_step=step)

                    if step % cfg.SOLVER.SNAPSHOT_ITERS == 0:
                        saver.save(sess, os.path.join(self.output_dir, 'model.ckpt'),
                                   global_step=self.solver.global_step)

                    if step >= cfg.SOLVER.MAX_ITERS:
                        break

            except tf.errors.OutOfRangeError:
                logger.info('Epochs Complete!')
            finally:
                coord.request_stop()
            coord.join(threads)
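self.solver.apply_gradients(tower_grads) receives one gradient list per GPU tower. A hedged sketch of the standard TF1 averaging step such a solver commonly performs internally (the actual Solver class may differ):

import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads: list (one per GPU) of lists of (grad, var) pairs from
    # optimizer.compute_gradients(); zip groups the same variable across towers.
    averaged = []
    for grads_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grads_and_vars if g is not None]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        averaged.append((grad, grads_and_vars[0][1]))
    return averaged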
Example #21
def main(config):
    # log file
    log_dir = "./log"
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(name)s %(levelname)s %(message)s',
        datefmt='%m-%d %H:%M',
        filename='{}/{}.log'.format(log_dir, config.model_prefix),
        filemode='a')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)s %(levelname)s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    # model folder
    model_dir = "./model"
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    # set up environment
    devs = [mx.gpu(int(i)) for i in config.gpu_list]
    kv = mx.kvstore.create(config.kv_store)

    # set up iterator and symbol
    # iterator
    train, val, num_examples = imagenet_iterator(data_dir=config.data_dir,
                                                 batch_size=config.batch_size,
                                                 kv=kv)
    data_names = ('data', )
    label_names = ('softmax_label', )
    data_shapes = [('data', (config.batch_size, 3, 224, 224))]
    label_shapes = [('softmax_label', (config.batch_size, ))]

    symbol = eval(config.network)(num_classes=config.num_classes,
                                  config=config)

    # train
    epoch_size = max(int(num_examples / config.batch_size / kv.num_workers), 1)
    if config.lr_step is not None:
        lr_scheduler = multi_factor_scheduler(config.begin_epoch,
                                              epoch_size,
                                              step=config.lr_step,
                                              factor=config.lr_factor)
    else:
        lr_scheduler = None

    optimizer_params = {
        'learning_rate': config.lr,
        'lr_scheduler': lr_scheduler,
        'wd': config.wd,
        'momentum': config.momentum
    }
    optimizer = "nag"
    eval_metric = ['acc']
    if config.dataset == "imagenet":
        eval_metric.append(mx.metric.create('top_k_accuracy', top_k=5))

    # knowledge transfer
    teacher_module = None
    if config.kt:
        eval_metric = [KTAccMetric()]
        if config.dataset == 'imagenet':
            eval_metric.append(KTTopkAccMetric(top_k=5))
        if len(config.kt_type.split('+')) > 1:
            logging.info(
                'knowledge transfer training by {} with weight {}'.format(
                    config.kt_type.split('+')[0], config.kt_weight[0]))
            logging.info(
                'knowledge transfer training by {} with weight {}'.format(
                    config.kt_type.split('+')[1], config.kt_weight[1]))
        else:
            logging.info(
                'knowledge transfer training by {} with weight {}'.format(
                    config.kt_type, config.kt_weight))
        label_names += config.kt_label_names
        label_shapes += config.kt_label_shapes
        logging.info('loading teacher model from {}-{:04d}'.format(
            config.teacher_prefix, config.teacher_epoch))
        teacher_symbol, teacher_arg_params, teacher_aux_params = mx.model.load_checkpoint(
            config.teacher_prefix, config.teacher_epoch)
        if len(config.kt_type.split('+')) > 1:
            teacher_symbol = mx.symbol.Group([
                teacher_symbol.get_internals()[config.teacher_symbol[0]],
                teacher_symbol.get_internals()[config.teacher_symbol[1]]
            ])
        else:
            teacher_symbol = teacher_symbol.get_internals()[
                config.teacher_symbol]
        teacher_module = mx.module.Module(teacher_symbol, context=devs)
        teacher_module.bind(data_shapes=data_shapes,
                            for_training=False,
                            grad_req='null')
        teacher_module.set_params(teacher_arg_params, teacher_aux_params)

    solver = Solver(symbol=symbol,
                    data_names=data_names,
                    label_names=label_names,
                    data_shapes=data_shapes,
                    label_shapes=label_shapes,
                    logger=logging,
                    context=devs)
    epoch_end_callback = mx.callback.do_checkpoint("./model/" +
                                                   config.model_prefix)
    batch_end_callback = mx.callback.Speedometer(config.batch_size,
                                                 config.frequent)
    arg_params = None
    aux_params = None
    if config.retrain:
        logging.info('retrain from {}-{:04d}'.format(config.model_load_prefix,
                                                     config.model_load_epoch))
        _, arg_params, aux_params = mx.model.load_checkpoint(
            "model/{}".format(config.model_load_prefix),
            config.model_load_epoch)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type='in',
                                 magnitude=2)

    solver.fit(train_data=train,
               eval_data=val,
               eval_metric=eval_metric,
               epoch_end_callback=epoch_end_callback,
               batch_end_callback=batch_end_callback,
               initializer=initializer,
               arg_params=arg_params,
               aux_params=aux_params,
               optimizer=optimizer,
               optimizer_params=optimizer_params,
               begin_epoch=config.begin_epoch,
               num_epoch=config.num_epoch,
               kvstore=kv,
               teacher_modules=teacher_module)
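multi_factor_scheduler above is a thin helper; a hedged sketch of what it typically builds (assuming MXNet 1.x; the repo's version may differ) converts the remaining epoch boundaries into iteration counts:

import mxnet as mx

def multi_factor_scheduler(begin_epoch, epoch_size, step, factor=0.1):
    # keep only the step epochs that lie after begin_epoch, expressed in iterations
    steps = [epoch_size * (s - begin_epoch) for s in step if s - begin_epoch > 0]
    if not steps:
        return None
    return mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=factor)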
Example #22
def test_net(cfg):
    # Set batch size to 1
    cfg.CONST.BATCH_SIZE = 1

    # Generate network and solver
    net = ReconstructionNet(cfg)
    print('[INFO] %s Loading network parameters from %s' % (dt.now(), cfg.CONST.WEIGHTS))
    net.load(cfg.CONST.WEIGHTS)
    
    solver = Solver(cfg, net)

    # Setup results container
    results = dict()
    results['samples']    = { t: [] for t in cfg.TEST.VOXEL_THRESH } # Result of each sample
    results['categories'] = { t: {} for t in cfg.TEST.VOXEL_THRESH } # Result of each category

    # Set up testing data process. We make only one prefetching process. The
    # process will return one batch at a time.
    queue = Queue(cfg.TRAIN.QUEUE_SIZE)
    data_pair = category_model_id_pair(cfg.DIR.DATASET_TAXONOMY_FILE_PATH, \
                                        cfg.DIR.DATASET_QUERY_PATH , \
                                        cfg.TEST.DATASET_PORTION)
    processes = make_processes(cfg, queue, data_pair, 1, repeat=False, train=False)
    n_test_data = len(processes[0].data_paths)

    # Main Test Process
    batch_idx  = 0
    categories = []

    for imgs, voxel in get_while_running(processes[0], queue):
        if batch_idx == n_test_data:
            break                       # Reach the end of the test sample + 1

        c_id = data_pair[batch_idx][0] # Category of the sample
        s_id = data_pair[batch_idx][1] # ID of the sample
        prediction, loss, activations = solver.test_output(imgs, voxel)

        # Binarize the ground-truth voxel grid (occupied = True)
        voxel = voxel > 0

        ious = []
        for threshold in cfg.TEST.VOXEL_THRESH:
            iou = get_iou(prediction[0, :, 1, :, :], voxel[0, :, 1, :, :], threshold)

            if not c_id in results['categories'][threshold]:
                results['categories'][threshold][c_id] = []
                if not c_id in categories:
                    categories.append(c_id)


            results['samples'][threshold].append(iou)
            results['categories'][threshold][c_id].append(iou)
            ious.append(iou)

        # Print test result of the sample
        print('[INFO] %04d/%d\tID: %s\tCategory: %s\tIoU: %s' % (batch_idx + 1, n_test_data, s_id, c_id, ious))
        batch_idx +=1

    # Print summarized results
    print('\n\n================ Test Report ================')

    print('Category ID', end='\t')
    for t in cfg.TEST.VOXEL_THRESH:
        print('%g' % t, end='\t')
    print()

    for c in categories:
        print(c, end='\t')
        for t in cfg.TEST.VOXEL_THRESH:
            print('%.4f' % np.mean(results['categories'][t][c]), end='\t')
        print()
        
    print('Mean', end='\t\t')
    for t in cfg.TEST.VOXEL_THRESH:
        print('%.4f' % np.mean(results['samples'][t]), end='\t')
    print()

    # Cleanup the processes and the queue.
    kill_processes(queue, processes)
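get_iou above compares a predicted occupancy grid against the binarized ground truth at each threshold. A hedged sketch of such a voxel IoU (an assumption, not the repo's exact code):

import numpy as np

def get_iou(prediction, ground_truth, threshold):
    occupied = prediction >= threshold                  # binarize predicted occupancies
    gt = ground_truth.astype(bool)
    intersection = np.logical_and(occupied, gt).sum()
    union = np.logical_or(occupied, gt).sum()
    return intersection / float(union) if union > 0 else 1.0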
Example #23
def main(config):
    output_dir = "experiments/" + config.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(name)s %(levelname)s %(message)s',
                        datefmt='%m-%d %H:%M',
                        filename='{}/{}.log'.format(output_dir, config.model_prefix),
                        filemode='a')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)s %(levelname)s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

    logging.info(config)

    # set up environment
    devs = [mx.gpu(int(i)) for i in config.gpu_list]
    kv = mx.kvstore.create(config.kv_store)

    # set up iterator and symbol
    # iterator
    if config.use_multiple_iter is True:
        train, val, num_examples = multiple_imagenet_iterator(data_dir=config.data_dir,
                                                    batch_size=config.batch_size,
                                                    num_parts=2,
                                                    image_shape=tuple(config.image_shape),
                                                    data_nthread=config.data_nthreads)
    # elif config.use_dali_iter is True:
    #     train, val, num_examples = get_dali_iter(data_dir=config.data_dir,
    #                                              batch_size=config.batch_size,
    #                                              kv=kv,
    #                                              image_shape=tuple(config.image_shape),
    #                                              num_gpus=len(devs))
    else:
        if config.dataset == 'imagenet':
            train, val, num_examples = imagenet_iterator(data_dir=config.data_dir,
                                                         batch_size=config.batch_size,
                                                         kv=kv,
                                                         image_shape=tuple(config.image_shape))
        elif config.dataset == 'cifar10':
            train, val, num_examples = cifar10_iterator(data_dir=config.data_dir,
                                                         batch_size=config.batch_size,
                                                         kv=kv,
                                                         image_shape=tuple(config.image_shape))
        elif config.dataset == 'cifar100':
            train, val, num_examples = cifar100_iterator(data_dir=config.data_dir,
                                                         batch_size=config.batch_size,
                                                         kv=kv,
                                                         image_shape=tuple(config.image_shape))
    logging.info(train)
    logging.info(val)

    data_names = ('data',)
    label_names = ('softmax_label',)
    data_shapes = [('data', tuple([config.batch_size] + config.image_shape))]
    label_shapes = [('softmax_label', (config.batch_size,))]

    if config.network in ['resnet', 'resnet_cifar10', 'preact_resnet']:
        symbol = eval(config.network)(units=config.units,
                                      num_stage=config.num_stage,
                                      filter_list=config.filter_list,
                                      num_classes=config.num_classes,
                                      data_type=config.data_type,
                                      bottle_neck=config.bottle_neck,
                                      grad_scale=config.grad_scale,
                                      memonger=config.memonger,
                                      dataset_type=config.dataset)
    elif config.network == 'resnet_mxnet':
        symbol = eval(config.network)(units=config.units,
                                      num_stage=config.num_stage,
                                      filter_list=config.filter_list,
                                      num_classes=config.num_classes,
                                      data_type=config.data_type,
                                      bottle_neck=config.bottle_neck,
                                      grad_scale=config.grad_scale)
    elif config.network == 'resnext' or config.network == 'resnext_cyt':
        symbol = eval(config.network)(units=config.units,
                                      num_stage=config.num_stage,
                                      filter_list=config.filter_list,
                                      num_classes=config.num_classes,
                                      data_type=config.data_type,
                                      num_group=config.num_group,
                                      bottle_neck=config.bottle_neck)
    elif config.network == 'vgg16' or config.network == 'mobilenet' or config.network == 'shufflenet':
        symbol = eval(config.network)(num_classes=config.num_classes)
    elif config.network == "cifar10_sym":
        symbol = eval(config.network)()


    if config.fix_bn:
        from core.graph_optimize import fix_bn
        print("********************* fix bn ***********************")
        symbol = fix_bn(symbol)

    if config.quantize_flag:
        assert config.data_type == "float32", "current quantization op only support fp32 mode."
        from core.graph_optimize import attach_quantize_node
        worker_data_shape = dict(data_shapes + label_shapes)
        _, out_shape, _ = symbol.get_internals().infer_shape(**worker_data_shape)
        out_shape_dictionary = dict(zip(symbol.get_internals().list_outputs(), out_shape))
        symbol = attach_quantize_node(symbol, out_shape_dictionary,
                                      config.quantize_setting["weight"], config.quantize_setting["act"], 
                                      config.quantized_op, config.skip_quantize_counts)
        # symbol.save("attach_quant.json")
        # raise NotImplementedError

    # symbol.save(config.network + ".json")
    # raise NotImplementedError
    # mx.viz.print_summary(symbol, {'data': (1, 3, 224, 224)})

    # memonger
    if config.memonger:
        # infer shape
        data_shape_dict = dict(train.provide_data + train.provide_label)
        per_gpu_data_shape_dict = {}
        for key in data_shape_dict:
            per_gpu_data_shape_dict[key] = (config.batch_per_gpu,) + data_shape_dict[key][1:]

        # if config.network == 'resnet':
        #     last_block = 'conv3_1_relu'
        #     if kv.rank == 0:
        #         logging.info("resnet do memonger up to {}".format(last_block))
        # else:
        #     last_block = None
        last_block = 'stage4_unit1_sc'
        input_dtype = {k: 'float32' for k in per_gpu_data_shape_dict}
        symbol = search_plan_to_layer(symbol, last_block, 1000, type_dict=input_dtype, **per_gpu_data_shape_dict)

    # train
    epoch_size = max(int(num_examples / config.batch_size / kv.num_workers), 1)
    if 'dist' in config.kv_store and not 'async' in config.kv_store \
            and config.use_multiple_iter is False and config.use_dali_iter is False:
        logging.info('Resizing training data to {} batches per machine'.format(epoch_size))
        # resize train iter to ensure each machine has same number of batches per epoch
        # if not, dist_sync can hang at the end with one machine waiting for other machines
        train = mx.io.ResizeIter(train, epoch_size)

    if config.warmup is not None and config.warmup is True:
        lr_epoch = [int(epoch) for epoch in config.lr_step]
        lr_epoch_diff = [epoch - config.begin_epoch for epoch in lr_epoch if epoch > config.begin_epoch]
        lr = config.lr * (config.lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
        lr_iters = [int(epoch * epoch_size) for epoch in lr_epoch_diff]
        logging.info('warmup lr:{}, warm_epoch:{}, warm_step:{}'.format(
            config.warmup_lr, config.warm_epoch, int(config.warm_epoch * epoch_size)))
        if config.lr_scheduler == 'poly':
            logging.info('PolyScheduler lr:{}'.format(lr))
            lr_scheduler = mx.lr_scheduler.PolyScheduler(int(epoch_size * config.num_epoch), base_lr=lr, pwr=2, final_lr=0,
                                                         warmup_steps=int(config.warm_epoch * epoch_size),
                                                         warmup_begin_lr=0, warmup_mode='linear')
        else:
            logging.info('WarmupMultiFactorScheduler lr:{}, epoch size:{}, lr_epoch_diff:{}, '
                         'lr_iters:{}'.format(lr, epoch_size, lr_epoch_diff, lr_iters))
            lr_scheduler = WarmupMultiFactorScheduler(base_lr=lr, step=lr_iters, factor=config.lr_factor,
                                                      warmup=True, warmup_type='gradual',
                                                      warmup_lr=config.warmup_lr, warmup_step=int(config.warm_epoch * epoch_size))
    elif config.lr_step is not None:
        lr_epoch_diff = [epoch - config.begin_epoch for epoch in config.lr_step if epoch > config.begin_epoch]
        lr = config.lr * (config.lr_factor **(len(config.lr_step) - len(lr_epoch_diff)))
        lr_scheduler = multi_factor_scheduler(config.begin_epoch, epoch_size, step=config.lr_step,
                                              factor=config.lr_factor)
        step_ = [epoch * epoch_size for epoch in lr_epoch_diff]
        logging.info('multi_factor_scheduler lr:{}, epoch size:{}, epoch diff:{}, '
                     'step:{}'.format(lr, epoch_size, lr_epoch_diff, step_))
    else:
        lr = config.lr
        lr_scheduler = None
    print("begin epoch:{}, num epoch:{}".format(config.begin_epoch, config.num_epoch))

    optimizer_params = {
        'learning_rate': lr,
        'wd': config.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': config.multi_precision}
    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if config.optimizer in has_momentum:
        optimizer_params['momentum'] = config.momentum
    # A limited number of optimizers have a warmup period
    has_warmup = {'lbsgd', 'lbnag'}
    if config.optimizer in has_warmup:
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['begin_epoch'] = config.begin_epoch
        optimizer_params['batch_scale'] = 1.0
        optimizer_params['warmup_strategy'] = 'lars'
        optimizer_params['warmup_epochs'] = config.warm_epoch  # has no effect when warmup_strategy is 'lars'
        optimizer_params['num_epochs'] = config.num_epoch

    eval_metric = ['acc']
    if config.dataset == "imagenet":
        eval_metric.append(mx.metric.create('top_k_accuracy', top_k=5))

    solver = Solver(symbol=symbol,
                    data_names=data_names,
                    label_names=label_names,
                    data_shapes=data_shapes,
                    label_shapes=label_shapes,
                    logger=logging,
                    context=devs,
                    # for evaluate fold bn
                    config=config)
    epoch_end_callback = mx.callback.do_checkpoint(os.path.join(output_dir, config.model_prefix))
    batch_end_callback = mx.callback.Speedometer(config.batch_size, config.frequent)
    # batch_end_callback = DetailSpeedometer(config.batch_size, config.frequent)
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
    arg_params = None
    aux_params = None
    if config.retrain:
        print('******************** retrain load pretrain model from: {}'.format(config.model_load_prefix))
        _, arg_params, aux_params = mx.model.load_checkpoint("{}".format(config.model_load_prefix),
                                                             config.model_load_epoch)
    solver.fit(train_data=train,
               eval_data=val,
               eval_metric=eval_metric,
               epoch_end_callback=epoch_end_callback,
               batch_end_callback=batch_end_callback,
               initializer=initializer,
               arg_params=arg_params,
               aux_params=aux_params,
               optimizer=config.optimizer,
               optimizer_params=optimizer_params,
               begin_epoch=config.begin_epoch,
               num_epoch=config.num_epoch,
               kvstore=kv,
               allow_missing=config.allow_missing)
Example #24
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    # create (or truncate) the csv file and write the header row;
    # 'w' with newline='' replaces the old 'wb' mode, which breaks csv.writer on Python 3
    with open(args.loss_csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([
            "epoch", "d_loss_z_trg", "d_loss_x_ref", "g_loss_z_trg",
            "g_loss_x_ref"
        ])

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        src_skt=get_train_loader(
                            root=args.train_sketch_img_dir,
                            which='source',
                            img_size=args.img_size,
                            batch_size=args.batch_size,
                            prob=args.randcrop_prob,
                            num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
Example #25
def main(config):
    # log file
    log_dir = "./log"
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(name)s %(levelname)s %(message)s',
        datefmt='%m-%d %H:%M',
        filename='{}/{}.log'.format(log_dir, config.model_prefix),
        filemode='a')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)s %(levelname)s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    # model folder
    model_dir = "./model"
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    # set up environment
    devs = [mx.gpu(int(i)) for i in config.gpu_list]
    kv = mx.kvstore.create(config.kv_store)

    # set up iterator and symbol
    # iterator
    train, val, num_examples = imagenet_iterator(data_dir=config.data_dir,
                                                 batch_size=config.batch_size,
                                                 kv=kv)
    print(train)
    print(val)
    data_names = ('data', )
    label_names = ('softmax_label', )
    data_shapes = [('data', (config.batch_size, 3, 224, 224))]
    label_shapes = [('softmax_label', (config.batch_size, ))]

    if config.network == 'resnet' or config.network == 'resnext':
        symbol = eval(config.network)(units=config.units,
                                      num_stage=config.num_stage,
                                      filter_list=config.filter_list,
                                      num_classes=config.num_classes,
                                      data_type=config.dataset,
                                      bottle_neck=config.bottle_neck)
    elif config.network == 'vgg16' or config.network == 'mobilenet' or config.network == 'shufflenet':
        symbol = eval(config.network)(num_classes=config.num_classes)

    # train
    epoch_size = max(int(num_examples / config.batch_size / kv.num_workers), 1)
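    # multi_factor_scheduler is a helper defined elsewhere in this repo; presumably it drops the
    # learning rate by `config.lr_factor` at each epoch listed in `config.lr_step`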
    if config.lr_step is not None:
        lr_scheduler = multi_factor_scheduler(config.begin_epoch,
                                              epoch_size,
                                              step=config.lr_step,
                                              factor=config.lr_factor)
    else:
        lr_scheduler = None

    optimizer_params = {
        'learning_rate': config.lr,
        'lr_scheduler': lr_scheduler,
        'wd': config.wd,
        'momentum': config.momentum
    }
    optimizer = "nag"
    #optimizer = 'sgd'
    eval_metric = ['acc']
    if config.dataset == "imagenet":
        eval_metric.append(mx.metric.create('top_k_accuracy', top_k=5))

    solver = Solver(symbol=symbol,
                    data_names=data_names,
                    label_names=label_names,
                    data_shapes=data_shapes,
                    label_shapes=label_shapes,
                    logger=logging,
                    context=devs)
    epoch_end_callback = mx.callback.do_checkpoint("./model/" +
                                                   config.model_prefix)
    batch_end_callback = mx.callback.Speedometer(config.batch_size,
                                                 config.frequent)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type='in',
                                 magnitude=2)
    arg_params = None
    aux_params = None
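    # optionally warm-start from a previously saved checkpoint under ./model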
    if config.retrain:
        _, arg_params, aux_params = mx.model.load_checkpoint(
            "model/{}".format(config.model_load_prefix),
            config.model_load_epoch)
    solver.fit(train_data=train,
               eval_data=val,
               eval_metric=eval_metric,
               epoch_end_callback=epoch_end_callback,
               batch_end_callback=batch_end_callback,
               initializer=initializer,
               arg_params=arg_params,
               aux_params=aux_params,
               optimizer=optimizer,
               optimizer_params=optimizer_params,
               begin_epoch=config.begin_epoch,
               num_epoch=config.num_epoch,
               kvstore=kv)
Example #26
def main(config):
    # log file
    log_dir = "./log"
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(name)s %(levelname)s %(message)s',
        datefmt='%m-%d %H:%M',
        filename='{}/{}.log'.format(log_dir, config.model_prefix),
        filemode='a')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)s %(levelname)s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    # model folder
    model_dir = "./model"
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    # set up environment
    devs = [mx.gpu(int(i)) for i in config.gpu_list]
    kv = mx.kvstore.create(config.kv_store)

    # set up iterator and symbol
    # iterator
    # data list
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    # ModelNet40 official train/test split
    TRAIN_FILES = provider.getDataFiles(
        os.path.join(BASE_DIR, config.train_files))
    TEST_FILES = provider.getDataFiles(
        os.path.join(BASE_DIR, config.val_files))
    train, val, num_examples = dummy_iterator(config.batch_size,
                                              config.num_points, TRAIN_FILES,
                                              TEST_FILES)

    data_names = ('data', )
    label_names = ('softmax_label', )
    data_shapes = [('data', (config.batch_size, config.num_points, 3))]
    label_shapes = [('softmax_label', (config.batch_size, ))]

    if config.network == 'pointnet':
        # integer division so each GPU receives a whole per-device batch
        symbol = eval(config.network)(num_classes=config.num_classes,
                                      batch_size=config.batch_size //
                                      len(config.gpu_list),
                                      num_points=config.num_points)
    # train
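    # decay the learning rate by `config.lr_factor` every ~200000/batch_size updates, never going below 1e-5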
    lr_scheduler = mx.lr_scheduler.FactorScheduler(step=int(200000 /
                                                            config.batch_size),
                                                   factor=config.lr_factor,
                                                   stop_factor_lr=1e-05)

    optimizer_params = {
        'learning_rate': config.lr,
        'lr_scheduler': lr_scheduler
    }
    optimizer = "adam"

    eval_metrics = mx.metric.CompositeEvalMetric()
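    # AccMetric and MatLossMetric are presumably custom metrics defined elsewhere in this repo
    # (classification accuracy and PointNet's feature-transform regularization loss)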
    if config.dataset == "modelnet40":
        for m in [AccMetric, MatLossMetric]:
            eval_metrics.add(m())

    solver = Solver(symbol=symbol,
                    data_names=data_names,
                    label_names=label_names,
                    data_shapes=data_shapes,
                    label_shapes=label_shapes,
                    logger=logging,
                    context=devs)
    epoch_end_callback = mx.callback.do_checkpoint("./model/" +
                                                   config.model_prefix)
    batch_end_callback = mx.callback.Speedometer(config.batch_size,
                                                 config.frequent)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type='in',
                                 magnitude=2)
    arg_params = None
    aux_params = None
    if config.retrain:
        _, arg_params, aux_params = mx.model.load_checkpoint(
            "model/{}".format(config.model_load_prefix),
            config.model_load_epoch)
    solver.fit(train_data=train,
               eval_data=val,
               eval_metric=eval_metrics,
               epoch_end_callback=epoch_end_callback,
               batch_end_callback=batch_end_callback,
               initializer=initializer,
               arg_params=arg_params,
               aux_params=aux_params,
               optimizer=optimizer,
               optimizer_params=optimizer_params,
               begin_epoch=config.begin_epoch,
               num_epoch=config.num_epoch,
               kvstore=kv)