Example #1
def train(args):
    writer = SummaryWriter()
    logger = make_logger(args.log_file)

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs

    data = ZSIH_dataloader(args.sketch_dir, args.image_dir, args.stats_file, args.embedding_file, packed, zs=args.zs)
    logger.info('Dataset size: {}'.format(len(data)))
    dataloader_train = DataLoader(dataset=data,
                                  num_workers=args.num_worker,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)

    logger.info('Building the model ...')
    model = ZSIM(args.hidden_size, args.hashing_bit, args.semantics_size, data.pretrain_embedding.float(), 
                 adj_scaler=args.adj_scaler, dropout=args.dropout, fix_cnn=args.fix_cnn, 
                 fix_embedding=args.fix_embedding, logger=logger)
    logger.info('Building the optimizer ...')
    optimizer = Adam(params=model.parameters(), lr=args.lr)
    #optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.9)
    # regularization weights: 1.0 for L1, 0.005 for L2 (the L2 weight is
    # divided back out when logging below)
    l1_regularization = _Regularization(model, 1, p=1, logger=logger)
    l2_regularization = _Regularization(model, 0.005, p=2, logger=logger)

    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    if args.gpu_id != -1:
        model.cuda(args.gpu_id)

    batch_acm = 0
    global_step = 0
    loss_p_xz_acm, loss_q_zx_acm, loss_image_l2_acm, loss_sketch_l2_acm, loss_reg_l2_acm, loss_reg_l1_acm = 0., 0., 0., 0., 0., 0.
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')
    while True:
        if patience <= 0:
            break
        for sketch_batch, image_batch, semantics_batch in dataloader_train:
            if global_step and global_step % args.print_every == 0 and batch_acm % args.cum_num == 0:
                logger.info('Iter {}, Loss/p_xz {:.3f}, Loss/q_zx {:.3f}, Loss/image_l2 {:.3f}, Loss/sketch_l2 {:.3f}, Loss/reg_l2 {:.3f}, Loss/reg_l1 {:.3f}'.format(
                             global_step,
                             loss_p_xz_acm/args.print_every/args.cum_num,
                             loss_q_zx_acm/args.print_every/args.cum_num,
                             loss_image_l2_acm/args.print_every/args.cum_num,
                             loss_sketch_l2_acm/args.print_every/args.cum_num,
                             loss_reg_l2_acm/args.print_every/args.cum_num,
                             loss_reg_l1_acm/args.print_every/args.cum_num))
                loss_p_xz_acm, loss_q_zx_acm, loss_image_l2_acm, loss_sketch_l2_acm, loss_reg_l2_acm, loss_reg_l1_acm = 0., 0., 0., 0., 0., 0.

            if global_step and global_step % args.save_every == 0 and batch_acm % args.cum_num == 0:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save({'args': args, 'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           '{}/Iter_{}.pkl'.format(args.save_dir, global_step))

                ### Evaluation
                model.eval()

                image_label = list()
                image_feature = list()
                for image, label in data.load_test_images(batch_size=args.batch_size):
                    if args.gpu_id != -1:
                        image = image.cuda(args.gpu_id)
                    image_label += label
                    tmp_feature = model.hash(image, 1).cpu().detach().numpy()
                    image_feature.append(tmp_feature)
                image_feature = np.vstack(image_feature)

                sketch_label = list()
                sketch_feature = list()
                for sketch, label in data.load_test_sketch(batch_size=args.batch_size):
                    if args.gpu_id != -1:
                        sketch = sketch.cuda(args.gpu_id)
                    sketch_label += label
                    tmp_feature = model.hash(sketch, 0).cpu().detach().numpy()
                    sketch_feature.append(tmp_feature)
                sketch_feature = np.vstack(sketch_feature)

                # Hamming distance between binary hash codes (the *_cosine
                # names are kept to match the logging tags below)
                dists_cosine = cdist(image_feature, sketch_feature, 'hamming')

                rank_cosine = np.argsort(dists_cosine, 0)

                for n in [5, 100, 200]:
                    ranksn_cosine = rank_cosine[:n, :].T

                    classesn_cosine = np.array([[image_label[i] == sketch_label[r] \
                                                for i in ranksn_cosine[r]] for r in range(len(ranksn_cosine))])

                    precision_cosine = np.mean(classesn_cosine)

                    writer.add_scalar('Precision_{}/cosine'.format(n),
                            precision_cosine, global_step)

                    logger.info('Iter {}, Precision_{}/cosine {}'.format(global_step, n, precision_cosine))

                # precision_cosine now holds Precision@200 (the last n above)
                if best_precision < precision_cosine:
                    patience = args.patience
                    best_precision = precision_cosine
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision, best_iter)
                    logger.info('Iter {}, Best Precision_200 {}'.format(global_step, best_precision))
                    torch.save({'args':args, 'model':model.state_dict(), \
                        'optimizer':optimizer.state_dict()}, '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
            if patience <= 0:
                break

            model.train()
            batch_acm += 1
            if global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr*global_step/args.warmup_steps)
            """
            #code for testing if the images and the sketches are corresponding to each other correctly

            for i in range(args.batch_size):
                sk = sketch_batch[i].numpy().reshape(224, 224, 3)
                im = image_batch[i].numpy().reshape(224, 224, 3)
                print(label[i])
                ims = np.vstack((np.uint8(sk), np.uint8(im)))
                cv2.imshow('test', ims)
                cv2.waitKey(3000)
            """

            sketch = sketch_batch
            image = image_batch
            semantics = semantics_batch.long()
            if args.gpu_id != -1:
                sketch = sketch.cuda(args.gpu_id)
                image = image.cuda(args.gpu_id)
                semantics = semantics.cuda(args.gpu_id)

            optimizer.zero_grad()
            loss = model(sketch, image, semantics)
            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()
            loss_p_xz_acm += loss['p_xz'][0].item()
            loss_q_zx_acm += loss['q_zx'][0].item()
            loss_image_l2_acm += loss['image_l2'][0].item()
            loss_sketch_l2_acm += loss['sketch_l2'][0].item()
            loss_reg_l1_acm += loss_l1.item()
            # divide the L2 weight (0.005) back out so the raw norm is logged
            loss_reg_l2_acm += (loss_l2.item() / 0.005)
            writer.add_scalar('Loss/p_xz', loss['p_xz'][0].item(), global_step)
            writer.add_scalar('Loss/q_zx', loss['q_zx'][0].item(), global_step)
            writer.add_scalar('Loss/image_l2', loss['image_l2'][0].item(), global_step)
            writer.add_scalar('Loss/sketch_l2', loss['sketch_l2'][0].item(), global_step)
            writer.add_scalar('Loss/reg_l2', (loss_l2.item() / 0.005), global_step)
            writer.add_scalar('Loss/reg_l1', loss_l1.item(), global_step)
            
            # each entry of loss is a (value, weight) pair; the L1 term is
            # logged only, not optimized
            loss_ = loss_l2
            for item in loss.values():
                loss_ += item[0] * item[1]
            loss_.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if batch_acm % args.cum_num == 0:
                optimizer.step()
                global_step += 1
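The warmup schedule above calls an update_lr helper that these examples never define. A minimal sketch, assuming it simply overwrites the learning rate of every parameter group (the body below is an assumption, not the original implementation):

def update_lr(optimizer, lr):
    # Hypothetical helper: set a new learning rate on every parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr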
Example #2
def train(args):
    # srun -p gpu --gres=gpu:1 python main_dsh.py
    sketch_folder, imsk_folder, im_folder, path_semantic, train_class, test_class = _parse_args_paths(
        args)
    logger = make_logger(join(mkdir(args.save_dir), curr_time_str() + '.log'))
    if DEBUG:
        train_class = train_class[:2]
        test_class = test_class[:2]
        args.print_every = 2
        args.save_every = 8
        args.steps = 20
        args.batch_size = 2
        args.npy_dir = NPY_FOLDER_SKETCHY

    data_train = DSH_dataloader(folder_sk=sketch_folder,
                                folder_im=im_folder,
                                clss=train_class,
                                folder_nps=args.npy_dir,
                                folder_imsk=imsk_folder,
                                normalize01=False,
                                doaug=False,
                                m=args.m,
                                path_semantic=path_semantic,
                                folder_saving=join(mkdir(args.save_dir),
                                                   'train_saving'),
                                logger=logger)
    dataloader_train = DataLoader(dataset=data_train,
                                  batch_size=args.batch_size,
                                  shuffle=False)
    data_test = DSH_dataloader(folder_sk=sketch_folder,
                               clss=test_class,
                               folder_nps=args.npy_dir,
                               path_semantic=path_semantic,
                               folder_imsk=imsk_folder,
                               normalize01=False,
                               doaug=False,
                               m=args.m,
                               folder_saving=join(mkdir(args.save_dir),
                                                  'test_saving'),
                               logger=logger)

    model = DSH(m=args.m, config=args.config)
    model.cuda()

    optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.9)

    steps = _try_load(args, logger, model, optimizer)
    logger.info(str(args))
    args.steps += steps
    dsh_loss = _DSH_loss(gamma=args.gamma)
    model.train()
    l2_regularization = _Regularization(model, args.l2_reg, p=2, logger=None)
    loss_sum = []
    # iterations
    while True:
        # 1. update D
        data_train.D = update_D(bi=data_train.BI,
                                bs=data_train.BS,
                                vec_bi=data_train.vec_bi,
                                vec_bs=data_train.vec_bs)
        # 2. update BI/BS
        feats_labels_sk, feats_labels_im = _extract_feats_sk_im(
            data=data_train, model=model, batch_size=args.batch_size)

        data_train.BI, data_train.BS = update_B(bi=data_train.BI,
                                                bs=data_train.BS,
                                                vec_bi=data_train.vec_bi,
                                                vec_bs=data_train.vec_bs,
                                                W=data_train.W,
                                                D=data_train.D,
                                                Fi=feats_labels_im[0],
                                                Fs=feats_labels_sk[0],
                                                lamb=args.lamb,
                                                gamma=args.gamma)
        # 3. update network parameters
        for _, (sketch, code_of_sketch, image, sketch_token,
                code_of_image) in enumerate(dataloader_train):

            sketch_feats, im_feats = model(sketch.cuda(), sketch_token.cuda(),
                                           image.cuda())
            loss = dsh_loss(sketch_feats, im_feats, code_of_sketch.cuda(), code_of_image.cuda()) \
                    + l2_regularization()
            loss = loss / args.update_every
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # step only once every update_every batches (gradient accumulation)
            if (steps + 1) % args.update_every == 0:
                optimizer.step()
                optimizer.zero_grad()
            # undo the 1/update_every scaling so the logged loss is per-batch
            loss_sum.append(float(loss.item() * args.update_every))
            if (steps + 1) % args.save_every == 0:
                _test_and_save(steps=steps,
                               optimizer=optimizer,
                               data_test=data_test,
                               model=model,
                               logger=logger,
                               args=args,
                               loss_sum=loss_sum)
                data_train.save_params()

            if (steps + 1) % args.print_every == 0:
                # collapse the accumulated losses into a single running mean
                loss_sum = [np.mean(loss_sum)]
                logger.info('step: {},  loss: {}'.format(steps, loss_sum[0]))

            steps += 1
            if steps >= args.steps: break
        dr_dec(optimizer=optimizer, args=args)
        if steps >= args.steps: break
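Example #2 decays the learning rate once per outer pass via dr_dec, which is also not shown. A plausible sketch, assuming a multiplicative decay whose factor lives on args (the lr_decay attribute is hypothetical):

def dr_dec(optimizer, args):
    # Hypothetical sketch: multiply each group's learning rate by a decay
    # factor; args.lr_decay is an assumed attribute, defaulting to 0.99.
    for param_group in optimizer.param_groups:
        param_group['lr'] *= getattr(args, 'lr_decay', 0.99)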
Example #3
def train(args):
    # srun -p gpu --gres=gpu:1 --output=d3shape_sketchy.out python main_d3shape.py --steps 50000 --print_every 200 --npy_dir 0 --save_every 1000 --batch_size 8 --dataset sketchy --save_dir d3shape_sketchy

    sketch_folder, imsk_folder, train_class, test_class = _parse_args_paths(
        args)

    data_train = D3Shape_dataloader(folder_sk=sketch_folder,
                                    clss=train_class,
                                    folder_nps=args.npy_dir,
                                    folder_imsk=imsk_folder,
                                    normalize01=False,
                                    doaug=False)
    dataloader_train = DataLoader(dataset=data_train,
                                  batch_size=args.batch_size,
                                  shuffle=False)

    data_test = D3Shape_dataloader(folder_sk=sketch_folder,
                                   clss=test_class,
                                   folder_nps=args.npy_dir,
                                   folder_imsk=imsk_folder,
                                   normalize01=False,
                                   doaug=False)

    model = D3Shape()
    model.cuda()
    optimizer = Adam(params=model.parameters(), lr=args.lr)
    logger = make_logger(join(mkdir(args.save_dir), curr_time_str() + '.log'))
    steps = _try_load(args, logger, model, optimizer)
    logger.info(str(args))
    args.steps += steps
    d3shape_loss = _D3Shape_loss(cp=args.cp, cn=args.cn)
    model.train()
    l2_regularization = _Regularization(model, args.l2_reg, p=2, logger=None)
    while True:
        loss_sum = []
        for _, (sketch1, imsk1, sketch2, imsk2,
                is_same) in enumerate(dataloader_train):
            optimizer.zero_grad()
            sketch1_feats, imsk1_feats = model(sketch1.cuda(), imsk1.cuda())
            sketch2_feats, imsk2_feats = model(sketch2.cuda(), imsk2.cuda())
            loss = d3shape_loss(sketch1_feats, imsk1_feats, sketch2_feats, imsk2_feats, is_same.cuda()) \
                    + l2_regularization()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_sum.append(float(loss.item()))
            if (steps + 1) % args.save_every == 0:
                model.eval()
                n = 50  # report Precision@n and mAP@n
                skip = 1
                start_time = time.time()
                feats_labels_sk = _extract_feats(data_test,
                                                 lambda sk: model(sk, None)[0],
                                                 SK,
                                                 skip=skip,
                                                 batch_size=args.batch_size)
                feats_labels_imsk = _extract_feats(
                    data_test,
                    lambda imsk: model(None, imsk)[0],
                    IMSK,
                    skip=skip,
                    batch_size=args.batch_size)
                pre, mAP = _eval(feats_labels_sk, feats_labels_imsk, n)
                logger.info(
                    "Precision@{}: {}, mAP@{}: {}".format(n, pre, n, mAP) +
                    "  " + 'step: {},  loss: {},  (eval time: {}s)'.format(
                        steps, np.mean(loss_sum),
                        time.time() - start_time))
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'steps': steps,
                        'args': args
                    }, save_fn(args.save_dir, steps, pre, mAP))
                model.train()

            if (steps + 1) % args.print_every == 0:
                logger.info('step: {},  loss: {}'.format(
                    steps, np.mean(loss_sum)))
                loss_sum = []

            steps += 1
            if steps >= args.steps: break
        if steps >= args.steps: break
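Every example builds a _Regularization callable that returns a weighted parameter norm; examples 1, 5, and 7 divide the returned value by the weight before logging, which implies it returns weight * ||params||_p. A minimal sketch under that assumption (the real class may filter parameters or use the logger):

import torch

class _Regularization(torch.nn.Module):
    # Hypothetical sketch: sum of p-norms over all parameters, scaled by a
    # fixed weight.
    def __init__(self, model, weight, p=2, logger=None):
        super().__init__()
        self.model = model
        self.weight = weight
        self.p = p

    def forward(self):
        reg = sum(param.norm(self.p) for param in self.model.parameters())
        return self.weight * reg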
Example #4
def train(args):
    # srun -p gpu --gres=gpu:1 --exclusive --output=san10.out python main_san.py --epochs 50000 --print_every 500 --save_every 2000 --batch_size 96 --dataset sketchy --margin 10 --npy_dir 0 --save_dir san_sketchy10
    # srun -p gpu --gres=gpu:1 --exclusive --output=san1.out python main_san.py --epochs 50000 --print_every 500 --save_every 2000 --batch_size 96 --dataset sketchy --margin 1 --npy_dir 0 --save_dir san_sketchy1

    # srun -p gpu --gres=gpu:1 --output=san_sketchy03.out python main_san.py --epochs 30000 --print_every 200 --save_every 3000 --batch_size 96 --dataset sketchy --margin 0.3 --npy_dir 0 --save_dir san_sketchy03 --lr 0.0001
    sketch_folder, image_folder, path_semantic, train_class, test_class = _parse_args_paths(
        args)

    if DEBUG:
        args.back_bone = 'default'
        args.npy_dir = NPY_FOLDER_SKETCHY
        args.ni_path = PATH_NAMES
        args.print_every = 1
        args.save_every = 5
        args.paired = True
        args.epochs = 20000
        # args.lr = 0.001
        args.sz = 32
        # args.l2_reg = 0.0001
        args.batch_size = 32
        args.h = 500

        test_class = train_class[5:7]
        train_class = train_class[:5]
    logger = make_logger(join(mkdir(args.save_dir), curr_time_str() + '.log'))
    data_train = CMT_dataloader(
        folder_sk=sketch_folder,
        clss=train_class,
        folder_nps=args.npy_dir,
        path_semantic=path_semantic,
        paired=args.paired,
        names=args.ni_path,
        folder_im=image_folder,
        normalize01=False,
        doaug=False,
        logger=logger,
        sz=None if args.back_bone == 'vgg' else args.sz)
    dataloader_train = DataLoader(dataset=data_train,
                                  batch_size=args.batch_size,
                                  shuffle=True)

    data_test = CMT_dataloader(folder_sk=sketch_folder,
                               clss=test_class,
                               folder_nps=args.npy_dir,
                               path_semantic=path_semantic,
                               folder_im=image_folder,
                               normalize01=False,
                               doaug=False,
                               logger=logger,
                               sz=None if args.back_bone == 'vgg' else args.sz)

    model = CMT(d=data_train.d(),
                h=args.h,
                back_bone=args.back_bone,
                batch_normalization=args.bn,
                sz=args.sz)
    model.cuda()

    if not args.ft:
        model.fix_vgg()
    optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.6)

    epochs = _try_load(args, logger, model, optimizer)
    logger.info(str(args))
    args.epochs += epochs
    cmt_loss = _CMT_loss()
    model.train()

    l2_regularization = _Regularization(model, args.l2_reg, p=2, logger=None)
    # seed with a zero so np.mean() is defined in the first _test_and_save call
    loss_sum = [[0], [0]]
    logger.info(
        "Start training:\n train_classes: {}\n test_classes: {}".format(
            train_class, test_class))
    _test_and_save(epochs=epochs,
                   optimizer=optimizer,
                   data_test=data_test,
                   model=model,
                   logger=logger,
                   args=args,
                   loss_sum=loss_sum)
    while True:
        # alternate updates between the image (IM) and sketch (SK) branches
        for mode, get_feat in [[IM, lambda data: model(im=data)],
                               [SK, lambda data: model(sk=data)]]:
            data_train.mode = mode
            for _, (data, semantics) in enumerate(dataloader_train):

                # Skip one-element batch in consideration of batch normalization
                if data.shape[0] == 1:
                    continue
                # print(data.shape)
                optimizer.zero_grad()
                loss = cmt_loss(get_feat(data.cuda()),
                                semantics.cuda()) \
                        + l2_regularization()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                loss_sum[mode].append(float(loss.item()))
        epochs += 1
        dr_dec(optimizer=optimizer, args=args)
        if (epochs + 1) % args.save_every == 0:
            _test_and_save(epochs=epochs,
                           optimizer=optimizer,
                           data_test=data_test,
                           model=model,
                           logger=logger,
                           args=args,
                           loss_sum=loss_sum)

        if (epochs + 1) % args.print_every == 0:
            logger.info('epochs: {},  loss_sk: {},  loss_im: {}'.format(
                epochs, np.mean(loss_sum[SK]), np.mean(loss_sum[IM])))
            loss_sum = [[], []]

        if epochs >= args.epochs: break
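CMT trains each branch to map its features onto the class semantic vectors, so _CMT_loss is plausibly a regression objective. A stand-in sketch, assuming plain mean-squared error (the original loss may add ranking or hinge terms):

import torch.nn as nn

class _CMT_loss(nn.Module):
    # Hypothetical stand-in: MSE between projected features and semantics.
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, feats, semantics):
        return self.mse(feats, semantics)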
Example #5
def train(args):
    writer = SummaryWriter()
    logger = make_logger(args.log_file)

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs

    logger.info('Loading the data ...')
    data = CMDTrans_data(args.sketch_dir,
                         args.image_dir,
                         args.stats_file,
                         args.embedding_file,
                         packed,
                         args.preprocess_data,
                         args.raw_data,
                         zs=args.zs,
                         sample_time=1,
                         cvae=True,
                         paired=False,
                         cut_part=False)
    dataloader_train = DataLoader(dataset=data, num_workers=args.num_worker, \
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)
    logger.info('Training sketch size: {}'.format(
        len(data.path2class_sketch.keys())))
    logger.info('Training image size: {}'.format(
        len(data.path2class_image.keys())))
    logger.info('Testing sketch size: {}'.format(
        len(data.path2class_sketch_test.keys())))
    logger.info('Testing image size: {}'.format(
        len(data.path2class_image_test.keys())))

    logger.info('Building the model ...')
    model = Regressor(args.raw_size,
                      args.hidden_size,
                      dropout_prob=args.dropout,
                      logger=logger)
    logger.info('Building the optimizer ...')
    optimizer = Adam(params=model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    l1_regularization = _Regularization(model,
                                        args.l1_weight,
                                        p=1,
                                        logger=logger)
    l2_regularization = _Regularization(model,
                                        args.l2_weight,
                                        p=2,
                                        logger=logger)

    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(
            args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    if args.gpu_id != -1:
        model.cuda(args.gpu_id)
    optimizer.zero_grad()

    loss_tri_acm = 0.
    loss_l1_acm = 0.
    loss_l2_acm = 0.
    batch_acm = 0
    global_step = 0
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')
    while True:
        if patience <= 0:
            break
        for sketch_batch, image_p_batch, image_n_batch, _semantics_batch in dataloader_train:
            sketch_batch = sketch_batch.float()
            image_p_batch = image_p_batch.float()
            image_n_batch = image_n_batch.float()
            if global_step and global_step % args.print_every == 0 and batch_acm % args.cum_num == 0:
                logger.info('*** Iter {} ***'.format(global_step))
                logger.info('        Loss/Triplet {:.3}'.format(
                    loss_tri_acm / args.print_every / args.cum_num))
                logger.info('        Loss/L1 {:.3}'.format(
                    loss_l1_acm / args.print_every / args.cum_num))
                logger.info('        Loss/L2 {:.3}'.format(
                    loss_l2_acm / args.print_every / args.cum_num))
                loss_tri_acm = 0.
                loss_l1_acm = 0.
                loss_l2_acm = 0.

            if global_step and global_step % args.save_every == 0 and batch_acm % args.cum_num == 0:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '{}/Iter_{}.pkl'.format(args.save_dir, global_step))

                ### Evaluation
                model.eval()

                image_label = list()
                image_feature = list()
                for image, label in data.load_test_images(
                        batch_size=args.batch_size):
                    image = image.float()
                    if args.gpu_id != -1:
                        image = image.cuda(args.gpu_id)
                    image_label += label
                    tmp_feature = model.inference_image(
                        image).cpu().detach().numpy()
                    image_feature.append(tmp_feature)
                image_feature = np.vstack(image_feature)

                sketch_label = list()
                sketch_feature = list()
                for sketch, label in data.load_test_sketch(
                        batch_size=args.batch_size):
                    sketch = sketch.float()
                    if args.gpu_id != -1:
                        sketch = sketch.cuda(args.gpu_id)
                    sketch_label += label
                    tmp_feature = model.inference_sketch(
                        sketch).cpu().detach().numpy()
                    sketch_feature.append(tmp_feature)
                sketch_feature = np.vstack(sketch_feature)

                Precision, mAP = cal_matrics_single(image_feature,
                                                    image_label,
                                                    sketch_feature,
                                                    sketch_label)

                writer.add_scalar('Precision_200/cosine', Precision,
                                  global_step)
                writer.add_scalar('mAP_200/cosine', mAP, global_step)
                logger.info('*** Evaluation Iter {} ***'.format(global_step))
                logger.info('        Precision {:.3}'.format(Precision))
                logger.info('        mAP {:.3}'.format(mAP))

                if best_precision < Precision:
                    patience = args.patience
                    best_precision = Precision
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision,
                                      best_iter)
                    logger.info('Iter {}, Best Precision_200 {:.3}'.format(
                        global_step, best_precision))
                    torch.save({'args':args, 'model':model.state_dict(), \
                        'optimizer':optimizer.state_dict()}, '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
            if patience <= 0:
                break

            model.train()
            batch_acm += 1
            if global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr * global_step / args.warmup_steps)

            if args.gpu_id != -1:
                sketch_batch = sketch_batch.cuda(args.gpu_id)
                image_p_batch = image_p_batch.cuda(args.gpu_id)
                image_n_batch = image_n_batch.cuda(args.gpu_id)

            loss = model(sketch_batch, image_p_batch, image_n_batch)

            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()
            loss_tri = loss.item()

            loss_l1_acm += (loss_l1.item() / args.l1_weight)
            loss_l2_acm += (loss_l2.item() / args.l2_weight)
            loss_tri_acm += loss_tri

            writer.add_scalar('Loss/Triplet', loss_tri, global_step)
            writer.add_scalar('Loss/Reg_l1', (loss_l1.item() / args.l1_weight),
                              global_step)
            writer.add_scalar('Loss/Reg_l2', (loss_l2.item() / args.l2_weight),
                              global_step)

            # the L1/L2 terms are logged only; the objective is the triplet loss
            loss.backward()

            if batch_acm % args.cum_num == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                global_step += 1
                optimizer.zero_grad()
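These trainers checkpoint dictionaries containing at least 'model' and 'optimizer' (most add 'args', example 3 also saves 'steps'), so resuming is uniform. A small helper matching that layout (the function name is illustrative):

import torch

def resume_from(path, model, optimizer):
    # Load a checkpoint written by the training loops above.
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt.get('args')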
Example #6
def train(args):
    writer = SummaryWriter()
    logger = make_logger(args.log_file)

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs

    data = Siamese_dataloader(args.sketch_dir,
                              args.image_dir,
                              args.stats_file,
                              packed,
                              zs=args.zs)
    logger.info('Dataset size: {}'.format(len(data)))
    dataloader_train = DataLoader(dataset=data,
                                  num_workers=args.num_worker,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)

    logger.info('Building the model ...')
    model = Siamese(args.margin,
                    args.loss_type,
                    args.distance_type,
                    batch_normalization=False,
                    from_pretrain=True,
                    logger=logger)
    logger.info('Building the optimizer ...')
    #optimizer = Adam(params=model.parameters(), lr=args.lr)
    optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.9)
    siamese_loss = _Siamese_loss()
    l1_regularization = _Regularization(model, 0.1, p=1, logger=logger)
    l2_regularization = _Regularization(model, 1e-4, p=2, logger=logger)

    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(
            args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    if args.gpu_id != -1:
        model.cuda(args.gpu_id)

    batch_acm = 0
    global_step = 0
    loss_siamese_acm, sim_acm, dis_sim_acm, loss_l1_acm, loss_l2_acm = 0., 0., 0., 0., 0.
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')
    while True:
        if patience <= 0:
            break
        for sketch_batch, image_batch, label_batch in dataloader_train:
            if global_step and global_step % args.print_every == 0 and batch_acm % args.cum_num == 0:
                logger.info('Iter {}, Loss/siamese {:.3f}, Loss/l1 {:.3f}, Loss/l2 {:.3f}, Siamese/sim {:.3f}, Siamese/dis_sim {:.3f}'.format(
                             global_step,
                             loss_siamese_acm/args.print_every/args.cum_num,
                             loss_l1_acm/args.print_every/args.cum_num,
                             loss_l2_acm/args.print_every/args.cum_num,
                             sim_acm/args.print_every/args.cum_num,
                             dis_sim_acm/args.print_every/args.cum_num))
                loss_siamese_acm, sim_acm, dis_sim_acm, loss_l1_acm, loss_l2_acm = 0., 0., 0., 0., 0.

            if global_step and global_step % args.save_every == 0 and batch_acm % args.cum_num == 0:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save({'args':args, 'model':model.state_dict(), \
                        'optimizer':optimizer.state_dict()},
                        '{}/Iter_{}.pkl'.format(args.save_dir,global_step))

                ### Evaluation
                model.eval()

                image_label = list()
                image_feature = list()
                for image, label in data.load_test_images(
                        batch_size=args.batch_size):
                    if args.gpu_id != -1:
                        image = image.cuda(args.gpu_id)
                    image_label += label
                    tmp_feature = model.get_feature(
                        image).cpu().detach().numpy()
                    image_feature.append(tmp_feature)
                image_feature = np.vstack(image_feature)

                sketch_label = list()
                sketch_feature = list()
                for sketch, label in data.load_test_sketch(
                        batch_size=args.batch_size):
                    if args.gpu_id != -1:
                        sketch = sketch.cuda(args.gpu_id)
                    sketch_label += label
                    tmp_feature = model.get_feature(
                        sketch).cpu().detach().numpy()
                    sketch_feature.append(tmp_feature)
                sketch_feature = np.vstack(sketch_feature)

                dists_cosine = cdist(image_feature, sketch_feature, 'cosine')
                dists_euclid = cdist(image_feature, sketch_feature,
                                     'euclidean')

                rank_cosine = np.argsort(dists_cosine, 0)
                rank_euclid = np.argsort(dists_euclid, 0)

                for n in [5, 200]:
                    ranksn_cosine = rank_cosine[:n, :].T
                    ranksn_euclid = rank_euclid[:n, :].T

                    classesn_cosine = np.array([[image_label[i] == sketch_label[r] \
                                                for i in ranksn_cosine[r]] for r in range(len(ranksn_cosine))])
                    classesn_euclid = np.array([[image_label[i] == sketch_label[r] \
                                                for i in ranksn_euclid[r]] for r in range(len(ranksn_euclid))])

                    precision_cosine = np.mean(classesn_cosine)
                    precision_euclid = np.mean(classesn_euclid)

                    writer.add_scalar('Precision_{}/cosine'.format(n),
                                      precision_cosine, global_step)
                    writer.add_scalar('Precision_{}/euclid'.format(n),
                                      precision_euclid, global_step)

                    logger.info('Iter {}, Precision_{}/cosine {}'.format(
                        global_step, n, precision_cosine))
                    logger.info('Iter {}, Precision_{}/euclid {}'.format(
                        global_step, n, precision_euclid))

                # precision_cosine holds Precision@200 (the last n above)
                if best_precision < precision_cosine:
                    patience = args.patience
                    best_precision = precision_cosine
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision,
                                      best_iter)
                    logger.info('Iter {}, Best Precision_200 {}'.format(
                        global_step, best_precision))
                    torch.save({'args':args, 'model':model.state_dict(), \
                        'optimizer':optimizer.state_dict()}, '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
            if patience <= 0:
                break

            model.train()
            batch_acm += 1
            if global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr * global_step / args.warmup_steps)
            """
            #code for testing if the images and the sketches are corresponding to each other correctly

            for i in range(args.batch_size):
                sk = sketch_batch[i].numpy().reshape(224, 224, 3)
                im = image_batch[i].numpy().reshape(224, 224, 3)
                print(label[i])
                ims = np.vstack((np.uint8(sk), np.uint8(im)))
                cv2.imshow('test', ims)
                cv2.waitKey(3000)
            """

            sketch = sketch_batch
            image = image_batch
            label = label_batch.float()
            if args.gpu_id != -1:
                sketch = sketch.cuda(args.gpu_id)
                image = image.cuda(args.gpu_id)
                label = label.cuda(args.gpu_id)

            optimizer.zero_grad()
            sketch_feature, image_feature = model(sketch, image)
            loss_siamese, sim, dis_sim = siamese_loss(
                sketch_feature,
                image_feature,
                label,
                args.margin,
                loss_type=args.loss_type,
                distance_type=args.distance_type)
            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()
            loss_siamese_acm += loss_siamese.item()
            sim_acm += sim.item()
            dis_sim_acm += dis_sim.item()
            loss_l1_acm += loss_l1.item()
            loss_l2_acm += loss_l2.item()
            writer.add_scalar('Loss/Siamese', loss_siamese.item(), global_step)
            writer.add_scalar('Loss/L1', loss_l1.item(), global_step)
            writer.add_scalar('Loss/L2', loss_l2.item(), global_step)
            writer.add_scalar('Siamese/Similar', sim.item(), global_step)
            writer.add_scalar('Siamese/Dis-Similar', dis_sim.item(),
                              global_step)
            # the L1 term is logged only; the objective is siamese + L2
            loss = loss_siamese + loss_l2
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if batch_acm % args.cum_num == 0:
                optimizer.step()
                global_step += 1
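Examples 1 and 6 compute Precision@n inline with the same cdist/argsort pattern; the logic factors cleanly into one helper. A sketch that mirrors the code above:

import numpy as np
from scipy.spatial.distance import cdist

def precision_at_n(image_feature, image_label, sketch_feature, sketch_label,
                   n=200, metric='cosine'):
    # Rank all images for each sketch query and count label matches in the
    # top n; rows of the distance matrix are images, columns are sketches.
    dists = cdist(image_feature, sketch_feature, metric)
    ranksn = np.argsort(dists, 0)[:n, :].T  # (num_sketches, n)
    hits = np.array([[image_label[i] == sketch_label[r] for i in ranksn[r]]
                     for r in range(len(ranksn))])
    return np.mean(hits)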
Example #7
def train(args):
    relu = nn.ReLU(inplace=True)
    writer = SummaryWriter()
    logger = make_logger(args.log_file)

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs

    logger.info('Loading the data ...')
    data = CMDTrans_data(args.sketch_dir,
                         args.image_dir,
                         args.stats_file,
                         args.embedding_file,
                         packed,
                         args.preprocess_data,
                         args.raw_data,
                         zs=args.zs,
                         sample_time=1,
                         cvae=True,
                         paired=True,
                         cut_part=False,
                         ranking=True,
                         tu_berlin=args.tu_berlin,
                         strong_pair=args.strong_pair)
    dataloader_train = DataLoader(dataset=data,
                                  num_workers=args.num_worker,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)
    logger.info('Training sketch size: {}'.format(
        len(data.path2class_sketch.keys())))
    logger.info('Training image size: {}'.format(
        len(data.path2class_image.keys())))
    logger.info('Testing sketch size: {}'.format(
        len(data.path2class_sketch_test.keys())))
    logger.info('Testing image size: {}'.format(
        len(data.path2class_image_test.keys())))

    logger.info('Building the model ...')
    model = CMDTrans_model(args.pca_size,
                           args.raw_size,
                           args.hidden_size,
                           args.semantics_size,
                           data.pretrain_embedding.float(),
                           dropout_prob=args.dropout,
                           fix_embedding=args.fix_embedding,
                           seman_dist=args.seman_dist,
                           triplet_dist=args.triplet_dist,
                           margin1=args.margin1,
                           margin2=args.margin2,
                           logger=logger)
    logger.info('Building the optimizer ...')
    optimizer = Adam(params=model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    #optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.9)
    l1_regularization = _Regularization(model,
                                        args.l1_weight,
                                        p=1,
                                        logger=logger)
    l2_regularization = _Regularization(model,
                                        args.l2_weight,
                                        p=2,
                                        logger=logger)

    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(
            args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    if args.gpu_id != -1:
        model.cuda(args.gpu_id)
    optimizer.zero_grad()

    # five model losses and two regularization losses
    loss_triplet_acm = 0.
    loss_orth_acm = 0.
    loss_kl_acm = 0.
    loss_img_acm = 0.
    loss_ske_acm = 0.
    loss_l1_acm = 0.
    loss_l2_acm = 0.
    # batch and optimization-step counters
    batch_acm = 0
    global_step = 0
    # best-result tracking
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')
    loss_weight = {'kl': 1.0, 'triplet': 1.0, 'orthogonality': 0.01,
                   'image': 1.0, 'sketch': 10.0}
    while True:
        if patience <= 0:
            break
        for sketch_batch, image_pair_batch, image_unpair_batch, image_n_batch in dataloader_train:
            if global_step and global_step % args.print_every == 0 and batch_acm % args.cum_num == 0:
                logger.info('*** Iter {} ***'.format(global_step))
                logger.info('        Loss/Triplet {:.3}'.format(
                    loss_triplet_acm / args.print_every / args.cum_num))
                logger.info('        Loss/Orthogonality {:.3}'.format(
                    loss_orth_acm / args.print_every / args.cum_num))
                logger.info('        Loss/KL {:.3}'.format(
                    loss_kl_acm / args.print_every / args.cum_num))
                logger.info('        Loss/Image {:.3}'.format(
                    loss_img_acm / args.print_every / args.cum_num))
                logger.info('        Loss/Sketch {:.3}'.format(
                    loss_ske_acm / args.print_every / args.cum_num))
                logger.info('        Loss/L1 {:.3}'.format(
                    loss_l1_acm / args.print_every / args.cum_num))
                logger.info('        Loss/L2 {:.3}'.format(
                    loss_l2_acm / args.print_every / args.cum_num))
                loss_triplet_acm = 0.
                loss_orth_acm = 0.
                loss_kl_acm = 0.
                loss_img_acm = 0.
                loss_ske_acm = 0.
                loss_l1_acm = 0.
                loss_l2_acm = 0.

            if global_step and global_step % args.save_every == 0 and batch_acm % args.cum_num == 0:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '{}/Iter_{}.pkl'.format(args.save_dir, global_step))

                ### Evaluation
                model.eval()

                image_label = list()
                image_feature1 = list()  # structure features (S)
                image_feature2 = list()  # generation features (G)
                for image, label in data.load_test_images(
                        batch_size=args.batch_size):
                    image = relu(image)
                    if args.gpu_id != -1:
                        image = image.float().cuda(args.gpu_id)
                    image_label += label
                    tmp_feature1 = model.inference_structure(
                        image, 'image').detach()  # S
                    tmp_feature2 = image.detach()  # G
                    image_feature1.append(tmp_feature1)
                    image_feature2.append(tmp_feature2)
                image_feature1 = torch.cat(image_feature1)
                image_feature2 = torch.cat(image_feature2)

                sketch_label = list()
                sketch_feature1 = list()  # structure features (S)
                sketch_feature2 = list()  # generation features (G)
                for sketch, label in data.load_test_sketch(
                        batch_size=args.batch_size):
                    sketch = relu(sketch)
                    if args.gpu_id != -1:
                        sketch = sketch.float().cuda(args.gpu_id)
                    sketch_label += label
                    tmp_feature1 = model.inference_structure(
                        sketch, 'sketch').detach()  # S
                    tmp_feature2 = model.inference_generation(
                        sketch).detach()  # G
                    sketch_feature1.append(tmp_feature1)
                    sketch_feature2.append(tmp_feature2)
                sketch_feature1 = torch.cat(sketch_feature1)
                sketch_feature2 = torch.cat(sketch_feature2)

                dists_cosine1 = cosine_distance(
                    image_feature1, sketch_feature1).cpu().detach().numpy()
                dists_cosine2 = cosine_distance(
                    image_feature2, sketch_feature2).cpu().detach().numpy()

                Precision_list, mAP_list, lambda_list, Precision_c, mAP_c = \
                    cal_matrics(dists_cosine1, dists_cosine2, image_label, sketch_label)

                logger.info('*** Evaluation Iter {} ***'.format(global_step))
                for idx, item in enumerate(lambda_list):
                    writer.add_scalar('Precision_200/{}'.format(item),
                                      Precision_list[idx], global_step)
                    writer.add_scalar('mAP_200/{}'.format(item), mAP_list[idx],
                                      global_step)
                    logger.info('        Precision/{} {:.3}'.format(
                        item, Precision_list[idx]))
                    logger.info('        mAP/{} {:.3}'.format(
                        item, mAP_list[idx]))
                writer.add_scalar('Precision_200/Compare', Precision_c,
                                  global_step)
                writer.add_scalar('mAP_200/Compare', mAP_c, global_step)
                logger.info(
                    '        Precision/Compare {:.3}'.format(Precision_c))
                logger.info('        mAP/Compare {:.3}'.format(mAP_c))

                Precision_list.append(Precision_c)
                Precision = max(Precision_list)
                if best_precision < Precision:
                    patience = args.patience
                    best_precision = Precision
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision,
                                      best_iter)
                    logger.info(
                        '=== Iter {}, Best Precision_200 {:.3} ==='.format(
                            global_step, best_precision))
                    torch.save({'args':args, 'model':model.state_dict(), \
                        'optimizer':optimizer.state_dict()}, '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
            if patience <= 0:
                break

            model.train()
            batch_acm += 1
            if global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr * global_step / args.warmup_steps)

            if args.gpu_id != -1:
                sketch_batch = relu(sketch_batch).float().cuda(args.gpu_id)
                image_pair_batch = relu(image_pair_batch).float().cuda(
                    args.gpu_id)
                image_unpair_batch = relu(image_unpair_batch).float().cuda(
                    args.gpu_id)
                image_n_batch = relu(image_n_batch).float().cuda(args.gpu_id)

            loss = model(sketch_batch, image_pair_batch, image_unpair_batch,
                         image_n_batch)

            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()
            loss_kl = loss['kl'].item()
            loss_orth = loss['orthogonality'].item()
            loss_triplet = loss['triplet'].item()
            loss_img = loss['image'].item()
            loss_ske = loss['sketch'].item()

            loss_l1_acm += (loss_l1.item() / args.l1_weight)
            loss_l2_acm += (loss_l2.item() / args.l2_weight)
            loss_kl_acm += loss_kl
            loss_orth_acm += loss_orth
            loss_triplet_acm += loss_triplet
            loss_img_acm += loss_img
            loss_ske_acm += loss_ske

            writer.add_scalar('Loss/KL', loss_kl, global_step)
            writer.add_scalar('Loss/Orthogonality', loss_orth, global_step)
            writer.add_scalar('Loss/Triplet', loss_triplet, global_step)
            writer.add_scalar('Loss/Image', loss_img, global_step)
            writer.add_scalar('Loss/Sketch', loss_ske, global_step)
            writer.add_scalar('Loss/Reg_l1', (loss_l1.item() / args.l1_weight),
                              global_step)
            writer.add_scalar('Loss/Reg_l2', (loss_l2.item() / args.l2_weight),
                              global_step)

            loss_ = 0
            loss_ += loss['image'] * loss_weight['image']
            loss_ += loss['sketch'] * loss_weight['sketch']
            loss_ += loss['triplet'] * loss_weight['triplet']
            loss_ += loss['kl'] * loss_weight['kl']
            # orthogonality loss is logged above but excluded from the objective
            #loss_ += loss['orthogonality']*loss_weight['orthogonality']
            loss_.backward()

            if batch_acm % args.cum_num == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                global_step += 1
                optimizer.zero_grad()
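Example #7 ranks with cosine_distance, which is expected to return a pairwise image-by-sketch distance matrix on the GPU. A minimal torch sketch, assuming distance is defined as 1 minus cosine similarity:

import torch
import torch.nn.functional as F

def cosine_distance(a, b, eps=1e-8):
    # a: (N, d) image features, b: (M, d) sketch features -> (N, M) distances.
    a_n = F.normalize(a, p=2, dim=1, eps=eps)
    b_n = F.normalize(b, p=2, dim=1, eps=eps)
    return 1.0 - a_n @ b_n.t()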