def train(args):
    """Train the ZSIM hashing model (ZSIH baseline) and evaluate it periodically.

    Mini-batches of ``(sketch, image, semantics)`` from ``ZSIH_dataloader``
    are fed to ``ZSIM``.  The optimised objective is the L2 regulariser plus,
    for every entry of the loss dict returned by the model,
    ``entry[0] * entry[1]`` (presumably a ``(value, weight)`` pair -- confirm
    in ``ZSIM``).  The L1 regulariser is computed and logged but not
    optimised.  Gradients are accumulated over ``args.cum_num`` mini-batches,
    clipped to norm 1.0 and applied with Adam.

    Every ``args.print_every`` optimiser steps the averaged loss terms are
    logged.  Every ``args.save_every`` steps a checkpoint
    ``<save_dir>/Iter_<step>.pkl`` is written and the model is evaluated on
    the test split: ``model.hash`` outputs are compared with Hamming distance
    and precision@5/100/200 is reported.  When precision@200 improves,
    ``<save_dir>/Best.pkl`` is written and patience is reset; otherwise
    patience is decremented and training stops once it reaches zero.

    Args:
        args: parsed command-line namespace with data paths and model /
            optimisation hyper-parameters.  ``args.gpu_id == -1`` runs on CPU.
    """
    writer = SummaryWriter()
    logger = make_logger(args.log_file)
    use_gpu = args.gpu_id != -1

    def _to_device(tensor):
        # ``gpu_id == -1`` selects CPU; ``tensor.cuda(-1)`` would fail.
        return tensor.cuda(args.gpu_id) if use_gpu else tensor

    def _window_mean(total):
        # Per-mini-batch average over the last reporting window.
        return total / args.print_every / args.cum_num

    def _precision_at(rank, n, image_label, sketch_label):
        # ``rank`` is sorted per column (one column per sketch); take the
        # top-``n`` images of each sketch and count label matches.
        top_n = rank[:n, :].T
        hits = np.array([[image_label[i] == sketch_label[r] for i in top_n[r]]
                         for r in range(len(top_n))])
        return np.mean(hits)

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs
    data = ZSIH_dataloader(args.sketch_dir, args.image_dir, args.stats_file,
                           args.embedding_file, packed, zs=args.zs)
    print(len(data))
    dataloader_train = DataLoader(dataset=data, num_workers=args.num_worker,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)

    logger.info('Building the model ...')
    model = ZSIM(args.hidden_size, args.hashing_bit, args.semantics_size,
                 data.pretrain_embedding.float(), adj_scaler=args.adj_scaler,
                 dropout=args.dropout, fix_cnn=args.fix_cnn,
                 fix_embedding=args.fix_embedding, logger=logger)
    # Move the model before the optimizer is built/restored: restored
    # optimizer state is placed on the parameters' device at load time, so
    # moving the model afterwards would leave Adam's state on the CPU.
    if use_gpu:
        model.cuda(args.gpu_id)

    logger.info('Building the optimizer ...')
    optimizer = Adam(params=model.parameters(), lr=args.lr)
    l1_regularization = _Regularization(model, 1, p=1, logger=logger)
    l2_regularization = _Regularization(model, 0.005, p=2, logger=logger)
    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])

    batch_acm = 0
    global_step = 0
    loss_p_xz_acm = loss_q_zx_acm = loss_image_l2_acm = 0.
    loss_sketch_l2_acm = loss_reg_l2_acm = loss_reg_l1_acm = 0.
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')

    optimizer.zero_grad()
    while True:
        if patience <= 0:
            break
        for sketch_batch, image_batch, semantics_batch in dataloader_train:
            # True on the first mini-batch of each accumulation window, i.e.
            # once per value of ``global_step``.
            at_step_boundary = batch_acm % args.cum_num == 0

            if global_step % args.print_every == 0 and global_step and at_step_boundary:
                logger.info(
                    'Iter {}, Loss/p_xz {:.3f}, Loss/q_zx {:.3f}, '
                    'Loss/image_l2 {:.3f}, Loss/sketch_l2 {:.3f}, '
                    'Loss/reg_l2 {:.3f}, Loss/reg_l1 {:.3f}'.format(
                        global_step,
                        _window_mean(loss_p_xz_acm),
                        _window_mean(loss_q_zx_acm),
                        _window_mean(loss_image_l2_acm),
                        _window_mean(loss_sketch_l2_acm),
                        _window_mean(loss_reg_l2_acm),
                        _window_mean(loss_reg_l1_acm)))
                loss_p_xz_acm = loss_q_zx_acm = loss_image_l2_acm = 0.
                loss_sketch_l2_acm = loss_reg_l2_acm = loss_reg_l1_acm = 0.

            if global_step % args.save_every == 0 and at_step_boundary and global_step:
                os.makedirs(args.save_dir, exist_ok=True)
                torch.save({'args': args, 'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           '{}/Iter_{}.pkl'.format(args.save_dir, global_step))

                # Evaluation: no autograd graph is needed for feature extraction.
                model.eval()
                with torch.no_grad():
                    image_label = list()
                    image_feature = list()
                    for image, label in data.load_test_images(batch_size=args.batch_size):
                        image = _to_device(image)
                        image_label += label
                        image_feature.append(model.hash(image, 1).cpu().detach().numpy())
                    image_feature = np.vstack(image_feature)

                    sketch_label = list()
                    sketch_feature = list()
                    for sketch, label in data.load_test_sketch(batch_size=args.batch_size):
                        sketch = _to_device(sketch)
                        sketch_label += label
                        sketch_feature.append(model.hash(sketch, 0).cpu().detach().numpy())
                    sketch_feature = np.vstack(sketch_feature)

                # The metric is Hamming distance; the "cosine" tags below are
                # kept so existing logs / TensorBoard runs stay comparable.
                dists = cdist(image_feature, sketch_feature, 'hamming')
                rank = np.argsort(dists, 0)
                for n in [5, 100, 200]:
                    precision_cosine = _precision_at(rank, n, image_label, sketch_label)
                    writer.add_scalar('Precision_{}/cosine'.format(n), precision_cosine, global_step)
                    logger.info('Iter {}, Precision_{}/cosine {}'.format(global_step, n, precision_cosine))

                # ``precision_cosine`` now holds precision@200 (last n above).
                if best_precision < precision_cosine:
                    patience = args.patience
                    best_precision = precision_cosine
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision, best_iter)
                    logger.info('Iter {}, Best Precision_200 {}'.format(global_step, best_precision))
                    torch.save({'args': args, 'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
                    if patience <= 0:
                        break
                model.train()

            batch_acm += 1
            # Linear warm-up from 0 to ``args.lr``; skipped when warmup_steps
            # is 0, which would otherwise divide by zero.
            if args.warmup_steps > 0 and global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr * global_step / args.warmup_steps)

            sketch = _to_device(sketch_batch)
            image = _to_device(image_batch)
            semantics = _to_device(semantics_batch.long())
            loss = model(sketch, image, semantics)
            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()

            p_xz = loss['p_xz'][0].item()
            q_zx = loss['q_zx'][0].item()
            image_l2 = loss['image_l2'][0].item()
            sketch_l2 = loss['sketch_l2'][0].item()
            reg_l1 = loss_l1.item()
            # Undo the 0.005 weight so the logged value is the raw L2 norm term.
            reg_l2 = loss_l2.item() / 0.005
            loss_p_xz_acm += p_xz
            loss_q_zx_acm += q_zx
            loss_image_l2_acm += image_l2
            loss_sketch_l2_acm += sketch_l2
            loss_reg_l1_acm += reg_l1
            loss_reg_l2_acm += reg_l2
            writer.add_scalar('Loss/p_xz', p_xz, global_step)
            writer.add_scalar('Loss/q_zx', q_zx, global_step)
            writer.add_scalar('Loss/image_l2', image_l2, global_step)
            writer.add_scalar('Loss/sketch_l2', sketch_l2, global_step)
            writer.add_scalar('Loss/reg_l2', reg_l2, global_step)
            writer.add_scalar('Loss/reg_l1', reg_l1, global_step)

            # Out-of-place sum so the regulariser tensor is not mutated.
            loss_ = loss_l2
            for item in loss.values():
                loss_ = loss_ + item[0] * item[1]
            loss_.backward()

            if batch_acm % args.cum_num == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # Clear only after stepping so gradients accumulate across
                # the ``cum_num`` mini-batches of the window.
                optimizer.zero_grad()
                global_step += 1
def train(args):
    """Train the DSH baseline with an alternating optimisation scheme.

    Each iteration of the outer ``while`` loop:

    1. recomputes ``data_train.D`` with ``update_D`` from ``data_train.BI`` /
       ``data_train.BS`` and their vectors;
    2. re-extracts sketch and image features with the current network
       (``_extract_feats_sk_im``) and replaces ``data_train.BI`` /
       ``data_train.BS`` with the output of ``update_B``;
    3. makes one pass over ``dataloader_train``, updating the network with
       ``_DSH_loss`` plus L2 regularisation and accumulating gradients over
       ``args.update_every`` mini-batches per SGD step;

    and then calls ``dr_dec`` on the optimiser.  Every ``args.save_every``
    steps ``_test_and_save`` runs and ``data_train.save_params()`` is called.
    Every ``args.print_every`` steps the loss is logged.  Training stops once
    ``steps`` reaches ``args.steps``, counted on top of any steps restored by
    ``_try_load``.

    Args:
        args: parsed command-line namespace; several attributes are
            overwritten when the module-level ``DEBUG`` flag is set.
    """
    # srun -p gpu --gres=gpu:1 python main_dsh.py
    sketch_folder, imsk_folder, im_folder, path_semantic, train_class, test_class = _parse_args_paths(
        args)
    logger = make_logger(join(mkdir(args.save_dir), curr_time_str() + '.log'))
    # Debug mode: two classes per split and a very short schedule.
    if DEBUG:
        train_class = train_class[:2]
        test_class = test_class[:2]
        args.print_every = 2
        args.save_every = 8
        args.steps = 20
        args.batch_size = 2
        args.npy_dir = NPY_FOLDER_SKETCHY
    # logger.info("try loading data_train")
    data_train = DSH_dataloader(folder_sk=sketch_folder, folder_im=im_folder, clss=train_class,
                                folder_nps=args.npy_dir, folder_imsk=imsk_folder, normalize01=False,
                                doaug=False, m=args.m, path_semantic=path_semantic,
                                folder_saving=join(mkdir(args.save_dir), 'train_saving'),
                                logger=logger)
    dataloader_train = DataLoader(dataset=data_train, batch_size=args.batch_size, shuffle=False)
    # logger.info("try loading data_test")
    data_test = DSH_dataloader(folder_sk=sketch_folder, clss=test_class, folder_nps=args.npy_dir,
                               path_semantic=path_semantic, folder_imsk=imsk_folder,
                               normalize01=False, doaug=False, m=args.m,
                               folder_saving=join(mkdir(args.save_dir), 'test_saving'),
                               logger=logger)
    model = DSH(m=args.m, config=args.config)
    model.cuda()
    optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.9)
    # logger.info("optimizer inited")
    # Resume (if a checkpoint is found) after the model is on the GPU, so the
    # restored optimizer state lands on the parameters' device.
    steps = _try_load(args, logger, model, optimizer)
    logger.info(str(args))
    # ``args.steps`` is a budget relative to the resumed step count.
    args.steps += steps
    dsh_loss = _DSH_loss(gamma=args.gamma)
    # NOTE(review): training mode is set only here. If _extract_feats_sk_im
    # or _test_and_save switch the model to eval mode, nothing below switches
    # it back -- confirm those helpers restore train mode.
    model.train()
    l2_regularization = _Regularization(model, args.l2_reg, p=2, logger=None)
    loss_sum = []
    # logger.info("iterations")
    # iterations
    while True:
        # logger.info("update D")
        # 1. update D
        data_train.D = update_D(bi=data_train.BI, bs=data_train.BS, vec_bi=data_train.vec_bi,
                                vec_bs=data_train.vec_bs)
        # logger.info("update BI/BS")
        # 2. update BI/BS
        feats_labels_sk, feats_labels_im = _extract_feats_sk_im(
            data=data_train, model=model, batch_size=args.batch_size)
        data_train.BI, data_train.BS = update_B(bi=data_train.BI, bs=data_train.BS,
                                                vec_bi=data_train.vec_bi, vec_bs=data_train.vec_bs,
                                                W=data_train.W, D=data_train.D,
                                                Fi=feats_labels_im[0], Fs=feats_labels_sk[0],
                                                lamb=args.lamb, gamma=args.gamma)
        # logger.info("update network parameters")
        # 3. update network parameters
        for _, (sketch, code_of_sketch, image, sketch_token, code_of_image) in enumerate(dataloader_train):
            sketch_feats, im_feats = model(sketch.cuda(), sketch_token.cuda(), image.cuda())
            loss = dsh_loss(sketch_feats, im_feats, code_of_sketch.cuda(), code_of_image.cuda()) \
                + l2_regularization()
            # Scale so the gradient accumulated over ``update_every``
            # mini-batches is their mean rather than their sum.
            loss = loss / args.update_every
            loss.backward()
            # NOTE(review): clipping runs on every mini-batch, i.e. on the
            # partially accumulated gradient, not only right before step().
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if (steps + 1) % args.update_every == 0:
                optimizer.step()
                optimizer.zero_grad()
            # Record the unscaled loss.
            loss_sum.append(float(loss.item() * args.update_every))
            if (steps + 1) % args.save_every == 0:
                _test_and_save(steps=steps, optimizer=optimizer, data_test=data_test, model=model,
                               logger=logger, args=args, loss_sum=loss_sum)
                data_train.save_params()
            if (steps + 1) % args.print_every == 0:
                # Collapse the window to its mean. That mean is carried into
                # the next window as one sample, so reports are a running
                # blend rather than a strict per-window average.
                loss_sum = [np.mean(loss_sum)]
                logger.info('step: {}, loss: {}'.format(steps, loss_sum[0]))
            steps += 1
            if steps >= args.steps:
                break
        # Per-outer-iteration schedule update (presumably learning-rate decay;
        # see dr_dec).
        dr_dec(optimizer=optimizer, args=args)
        if steps >= args.steps:
            break
def train(args):
    """Train the D3Shape baseline on pairs of sketch / ``imsk`` inputs.

    Every mini-batch provides two ``(sketch, imsk)`` pairs and an ``is_same``
    flag.  Both pairs go through the shared ``D3Shape`` network, and
    ``_D3Shape_loss`` plus L2 regularisation is minimised with Adam (one
    optimiser step per mini-batch, gradients clipped to norm 1.0).

    Every ``args.save_every`` steps the model is evaluated on the test
    classes (Precision@50 / mAP@50) and a checkpoint is written to the path
    returned by ``save_fn``.  Every ``args.print_every`` steps the mean loss
    since the last report is logged.  Training ends once ``steps`` reaches
    ``args.steps``, counted on top of any steps restored by ``_try_load``.

    Args:
        args: parsed command-line namespace.
    """
    # srun -p gpu --gres=gpu:1 --output=d3shape_sketchy.out python main_d3shape.py --steps 50000 --print_every 200 --npy_dir 0 --save_every 1000 --batch_size 8 --dataset sketchy --save_dir d3shape_sketchy
    sketch_folder, imsk_folder, train_class, test_class = _parse_args_paths(args)

    # Loader options shared by the training and test splits.
    shared = dict(folder_sk=sketch_folder, folder_nps=args.npy_dir,
                  folder_imsk=imsk_folder, normalize01=False, doaug=False)
    data_train = D3Shape_dataloader(clss=train_class, **shared)
    dataloader_train = DataLoader(dataset=data_train,
                                  batch_size=args.batch_size,
                                  shuffle=False)
    data_test = D3Shape_dataloader(clss=test_class, **shared)

    model = D3Shape()
    model.cuda()
    optimizer = Adam(params=model.parameters(), lr=args.lr)
    logger = make_logger(join(mkdir(args.save_dir), curr_time_str() + '.log'))
    steps = _try_load(args, logger, model, optimizer)
    logger.info(str(args))
    args.steps += steps
    criterion = _D3Shape_loss(cp=args.cp, cn=args.cn)
    model.train()
    l2_regularization = _Regularization(model, args.l2_reg, p=2, logger=None)

    top_k = 50  # cut-off for Precision@k / mAP@k

    def embed_sketch(batch):
        return model(batch, None)[0]

    def embed_imsk(batch):
        return model(None, batch)[0]

    def evaluate_and_checkpoint(step, losses):
        """Score the test split, log the result and save a checkpoint."""
        model.eval()
        tic = time.time()
        sk_feats = _extract_feats(data_test, embed_sketch, SK,
                                  skip=1, batch_size=args.batch_size)
        imsk_feats = _extract_feats(data_test, embed_imsk, IMSK,
                                    skip=1, batch_size=args.batch_size)
        precision, m_ap = _eval(sk_feats, imsk_feats, top_k)
        head = "Precision@{}: {}, mAP@{}: {}".format(top_k, precision, top_k, m_ap)
        tail = 'step: {}, loss: {}, (eval cpu time: {}s)'.format(
            step, np.mean(losses), time.time() - tic)
        logger.info(head + " " + tail)
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'steps': step,
            'args': args,
        }
        torch.save(checkpoint, save_fn(args.save_dir, step, precision, m_ap))
        model.train()

    while True:
        recent_losses = []
        for sketch1, imsk1, sketch2, imsk2, is_same in dataloader_train:
            optimizer.zero_grad()
            sk1_out, imsk1_out = model(sketch1.cuda(), imsk1.cuda())
            sk2_out, imsk2_out = model(sketch2.cuda(), imsk2.cuda())
            objective = criterion(sk1_out, imsk1_out, sk2_out, imsk2_out,
                                  is_same.cuda()) + l2_regularization()
            objective.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            recent_losses.append(float(objective.item()))

            upcoming = steps + 1
            if upcoming % args.save_every == 0:
                evaluate_and_checkpoint(steps, recent_losses)
            if upcoming % args.print_every == 0:
                logger.info('step: {}, loss: {}'.format(steps, np.mean(recent_losses)))
                recent_losses = []
            steps = upcoming
            if steps >= args.steps:
                break
        if steps >= args.steps:
            break
def train(args):
    """Train the CMT (SAN) baseline, alternating image and sketch passes.

    Each iteration of the outer ``while`` loop makes one pass over
    ``dataloader_train`` in image mode (``IM``) and one in sketch mode
    (``SK``).  In each pass the features of that modality are scored against
    the batch's ``semantics`` with ``_CMT_loss`` plus L2 regularisation and
    optimised with SGD.  After both passes one epoch is counted and
    ``dr_dec`` is applied.  On schedule, ``_test_and_save`` runs and the
    per-modality mean losses are logged.  Training stops once ``epochs``
    reaches ``args.epochs``, counted on top of any epochs restored by
    ``_try_load``.

    Args:
        args: parsed command-line namespace; many attributes are overwritten
            when the module-level ``DEBUG`` flag is set.
    """
    # srun -p gpu --gres=gpu:1 --exclusive --output=san10.out python main_san.py --epochs 50000 --print_every 500 --save_every 2000 --batch_size 96 --dataset sketchy --margin 10 --npy_dir 0 --save_dir san_sketchy10
    # srun -p gpu --gres=gpu:1 --exclusive --output=san1.out python main_san.py --epochs 50000 --print_every 500 --save_every 2000 --batch_size 96 --dataset sketchy --margin 1 --npy_dir 0 --save_dir san_sketchy1
    # srun -p gpu --gres=gpu:1 --output=san_sketchy03.out python main_san.py --epochs 30000 --print_every 200 --save_every 3000 --batch_size 96 --dataset sketchy --margin 0.3 --npy_dir 0 --save_dir san_sketchy03 --lr 0.0001
    sketch_folder, image_folder, path_semantic, train_class, test_class = _parse_args_paths(
        args)
    # Debug mode: small back-bone input, small batches, a short schedule and
    # a 5-class train / 2-class test split taken from the training classes.
    if DEBUG:
        args.back_bone = 'default'
        args.npy_dir = NPY_FOLDER_SKETCHY
        args.ni_path = PATH_NAMES
        args.print_every = 1
        args.save_every = 5
        args.paired = True
        args.epochs = 20000
        # args.lr = 0.001
        args.sz = 32
        # args.l2_reg = 0.0001
        args.back_bone = 'default'
        args.batch_size = 32
        args.h = 500
        test_class = train_class[5:7]
        train_class = train_class[:5]
    logger = make_logger(join(mkdir(args.save_dir), curr_time_str() + '.log'))
    # With the VGG back-bone no explicit input size is passed to the loader.
    data_train = CMT_dataloader(
        folder_sk=sketch_folder, clss=train_class, folder_nps=args.npy_dir,
        path_semantic=path_semantic, paired=args.paired, names=args.ni_path,
        folder_im=image_folder, normalize01=False, doaug=False, logger=logger,
        sz=None if args.back_bone == 'vgg' else args.sz)
    dataloader_train = DataLoader(dataset=data_train, batch_size=args.batch_size, shuffle=True)
    data_test = CMT_dataloader(folder_sk=sketch_folder, clss=test_class, folder_nps=args.npy_dir,
                               path_semantic=path_semantic, folder_im=image_folder,
                               normalize01=False, doaug=False, logger=logger,
                               sz=None if args.back_bone == 'vgg' else args.sz)
    model = CMT(d=data_train.d(), h=args.h, back_bone=args.back_bone,
                batch_normalization=args.bn, sz=args.sz)
    model.cuda()
    # Unless fine-tuning (args.ft) was requested, call fix_vgg (presumably
    # freezes the VGG back-bone).
    if not args.ft:
        model.fix_vgg()
    optimizer = SGD(params=model.parameters(),
                    lr=args.lr, momentum=0.6)
    epochs = _try_load(args, logger, model, optimizer)
    logger.info(str(args))
    # ``args.epochs`` is a budget relative to the resumed epoch count.
    args.epochs += epochs
    cmt_loss = _CMT_loss()
    model.train()
    l2_regularization = _Regularization(model, args.l2_reg, p=2, logger=None)
    # One loss history per modality; IM and SK index this two-element list.
    # Each starts with a 0, presumably so the evaluation below (before any
    # training) never averages an empty list -- confirm in _test_and_save.
    loss_sum = [[0], [0]]
    logger.info(
        "Start training:\n train_classes: {}\n test_classes: {}".format(
            train_class, test_class))
    # Baseline evaluation before any training.
    _test_and_save(epochs=epochs, optimizer=optimizer, data_test=data_test, model=model,
                   logger=logger, args=args, loss_sum=loss_sum)
    while True:
        for mode, get_feat in [[IM, lambda data: model(im=data)],
                               [SK, lambda data: model(sk=data)]]:
            # The same dataset object serves both modalities; ``mode``
            # selects which one it yields.
            data_train.mode = mode
            for _, (data, semantics) in enumerate(dataloader_train):
                # Skip one-element batch in consideration of batch normalization
                if data.shape[0] == 1:
                    continue
                # print(data.shape)
                optimizer.zero_grad()
                loss = cmt_loss(get_feat(data.cuda()), semantics.cuda()) \
                    + l2_regularization()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                loss_sum[mode].append(float(loss.item()))
        # NOTE(review): this bookkeeping is placed once per pass over both
        # modalities (one "epoch"). That is the only nesting in which the
        # single ``break`` below terminates the outer loop -- confirm against
        # the original layout.
        epochs += 1
        dr_dec(optimizer=optimizer, args=args)
        if (epochs + 1) % args.save_every == 0:
            _test_and_save(epochs=epochs, optimizer=optimizer, data_test=data_test, model=model,
                           logger=logger, args=args, loss_sum=loss_sum)
        if (epochs + 1) % args.print_every == 0:
            logger.info('epochs: {}, loss_sk: {}, loss_im: {},'.format(
                epochs, np.mean(loss_sum[SK]), np.mean(loss_sum[IM])))
            loss_sum = [[], []]
        if epochs >= args.epochs:
            break
def train(args):
    """Train the ``Regressor`` triplet model on ``CMDTrans_data`` inputs.

    Each mini-batch ``(sketch, image_p, image_n, _)`` (presumably anchor,
    positive and negative) is cast to float and passed to ``Regressor``.
    The loss it returns (logged as "Triplet") is the only optimised term.
    The L1/L2 regularisers are computed and logged, divided by their
    weights, but are not back-propagated.  Gradients are accumulated over
    ``args.cum_num`` mini-batches, clipped to norm 1.0 and applied with Adam.

    Every ``args.print_every`` optimiser steps the averaged losses are logged.
    Every ``args.save_every`` steps a checkpoint ``<save_dir>/Iter_<step>.pkl``
    is written, and test features from ``inference_image`` /
    ``inference_sketch`` are scored with ``cal_matrics_single``.  An improved
    precision writes ``<save_dir>/Best.pkl`` and resets patience.  Otherwise
    patience is decremented and training stops when it reaches zero.

    Args:
        args: parsed command-line namespace; ``args.gpu_id == -1`` runs on CPU.
    """
    writer = SummaryWriter()
    logger = make_logger(args.log_file)
    use_gpu = args.gpu_id != -1

    def _to_device(tensor):
        # ``gpu_id == -1`` selects CPU execution.
        return tensor.cuda(args.gpu_id) if use_gpu else tensor

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs
    logger.info('Loading the data ...')
    data = CMDTrans_data(args.sketch_dir, args.image_dir, args.stats_file,
                         args.embedding_file, packed, args.preprocess_data,
                         args.raw_data, zs=args.zs, sample_time=1, cvae=True,
                         paired=False, cut_part=False)
    dataloader_train = DataLoader(dataset=data, num_workers=args.num_worker,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)
    logger.info('Training sketch size: {}'.format(len(data.path2class_sketch.keys())))
    logger.info('Training image size: {}'.format(len(data.path2class_image.keys())))
    logger.info('Testing sketch size: {}'.format(len(data.path2class_sketch_test.keys())))
    logger.info('Testing image size: {}'.format(len(data.path2class_image_test.keys())))

    logger.info('Building the model ...')
    model = Regressor(args.raw_size, args.hidden_size,
                      dropout_prob=args.dropout, logger=logger)
    # Move the model before the optimizer is built/restored so that restored
    # optimizer state is created on the parameters' device.
    if use_gpu:
        model.cuda(args.gpu_id)

    logger.info('Building the optimizer ...')
    optimizer = Adam(params=model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    l1_regularization = _Regularization(model, args.l1_weight, p=1, logger=logger)
    l2_regularization = _Regularization(model, args.l2_weight, p=2, logger=logger)
    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    optimizer.zero_grad()

    # Running sums since the last report, in logging order.
    running = {'Triplet': 0., 'L1': 0., 'L2': 0.}
    batch_acm = 0
    global_step = 0
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')

    while True:
        if patience <= 0:
            break
        for sketch_batch, image_p_batch, image_n_batch, _semantics_batch in dataloader_train:
            sketch_batch = sketch_batch.float()
            image_p_batch = image_p_batch.float()
            image_n_batch = image_n_batch.float()
            # True once per value of ``global_step`` (start of a window).
            at_step_boundary = batch_acm % args.cum_num == 0

            if global_step % args.print_every == 0 and global_step and at_step_boundary:
                logger.info('*** Iter {} ***'.format(global_step))
                for name, total in running.items():
                    logger.info(' Loss/{} {:.3}'.format(
                        name, total / args.print_every / args.cum_num))
                running = dict.fromkeys(running, 0.)

            if global_step % args.save_every == 0 and at_step_boundary and global_step:
                os.makedirs(args.save_dir, exist_ok=True)
                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '{}/Iter_{}.pkl'.format(args.save_dir, global_step))

                # Evaluation: feature extraction needs no autograd graph.
                model.eval()
                with torch.no_grad():
                    image_label = list()
                    image_feature = list()
                    for image, label in data.load_test_images(batch_size=args.batch_size):
                        image = _to_device(image.float())
                        image_label += label
                        image_feature.append(
                            model.inference_image(image).cpu().detach().numpy())
                    image_feature = np.vstack(image_feature)

                    sketch_label = list()
                    sketch_feature = list()
                    for sketch, label in data.load_test_sketch(batch_size=args.batch_size):
                        sketch = _to_device(sketch.float())
                        sketch_label += label
                        sketch_feature.append(
                            model.inference_sketch(sketch).cpu().detach().numpy())
                    sketch_feature = np.vstack(sketch_feature)

                Precision, mAP = cal_matrics_single(image_feature, image_label,
                                                    sketch_feature, sketch_label)
                writer.add_scalar('Precision_200/cosine', Precision, global_step)
                writer.add_scalar('mAP_200/cosine', mAP, global_step)
                logger.info('*** Evaluation Iter {} ***'.format(global_step))
                logger.info(' Precision {:.3}'.format(Precision))
                logger.info(' mAP {:.3}'.format(mAP))
                if best_precision < Precision:
                    patience = args.patience
                    best_precision = Precision
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision, best_iter)
                    logger.info('Iter {}, Best Precision_200 {:.3}'.format(
                        global_step, best_precision))
                    torch.save({'args': args, 'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
                    if patience <= 0:
                        break
                model.train()

            batch_acm += 1
            # Linear warm-up; skipped when warmup_steps is 0 (division by zero).
            if args.warmup_steps > 0 and global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr * global_step / args.warmup_steps)

            sketch_batch = _to_device(sketch_batch)
            image_p_batch = _to_device(image_p_batch)
            image_n_batch = _to_device(image_n_batch)
            loss = model(sketch_batch, image_p_batch, image_n_batch)
            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()

            loss_tri = loss.item()
            # Undo the regulariser weights so the raw norm terms are logged.
            reg_l1 = loss_l1.item() / args.l1_weight
            reg_l2 = loss_l2.item() / args.l2_weight
            running['Triplet'] += loss_tri
            running['L1'] += reg_l1
            running['L2'] += reg_l2
            writer.add_scalar('Loss/Triplet', loss_tri, global_step)
            writer.add_scalar('Loss/Reg_l1', reg_l1, global_step)
            writer.add_scalar('Loss/Reg_l2', reg_l2, global_step)

            # Only the model loss is optimised; the regularisers are logged only.
            loss.backward()
            if batch_acm % args.cum_num == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                global_step += 1
                optimizer.zero_grad()
def train(args):
    """Train the ``Siamese`` sketch/image model with periodic evaluation.

    Each mini-batch ``(sketch, image, label)`` is embedded by ``Siamese``.
    The objective is ``_Siamese_loss`` plus the L2 regulariser.  The L1
    regulariser and the "sim" / "dis_sim" statistics returned by the loss are
    logged only.  Gradients are accumulated over ``args.cum_num`` mini-batches,
    clipped to norm 1.0 and applied with SGD (momentum 0.9).

    Every ``args.print_every`` optimiser steps the averaged statistics are
    logged.  Every ``args.save_every`` steps a checkpoint
    ``<save_dir>/Iter_<step>.pkl`` is written and ``model.get_feature``
    embeddings of the test split are ranked by cosine and by Euclidean
    distance (precision@5 and @200 reported for both).  When cosine
    precision@200 improves, ``<save_dir>/Best.pkl`` is written and patience
    is reset.  Otherwise patience is decremented and training stops when it
    reaches zero.

    Args:
        args: parsed command-line namespace; ``args.gpu_id == -1`` runs on CPU.
    """
    writer = SummaryWriter()
    logger = make_logger(args.log_file)
    use_gpu = args.gpu_id != -1

    def _to_device(tensor):
        # ``gpu_id == -1`` selects CPU; ``tensor.cuda(-1)`` would fail.
        return tensor.cuda(args.gpu_id) if use_gpu else tensor

    def _window_mean(total):
        # Per-mini-batch average over the last reporting window.
        return total / args.print_every / args.cum_num

    def _precision_at(rank, n, image_label, sketch_label):
        # ``rank`` is sorted per column (one column per sketch); take the
        # top-``n`` images of each sketch and count label matches.
        top_n = rank[:n, :].T
        hits = np.array([[image_label[i] == sketch_label[r] for i in top_n[r]]
                         for r in range(len(top_n))])
        return np.mean(hits)

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs
    data = Siamese_dataloader(args.sketch_dir, args.image_dir, args.stats_file,
                              packed, zs=args.zs)
    print(len(data))
    dataloader_train = DataLoader(dataset=data, num_workers=args.num_worker,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)

    logger.info('Building the model ...')
    model = Siamese(args.margin, args.loss_type, args.distance_type,
                    batch_normalization=False, from_pretrain=True, logger=logger)
    # Move the model before the optimizer is built/restored so that restored
    # momentum buffers are created on the parameters' device.
    if use_gpu:
        model.cuda(args.gpu_id)

    logger.info('Building the optimizer ...')
    optimizer = SGD(params=model.parameters(), lr=args.lr, momentum=0.9)
    siamese_loss = _Siamese_loss()
    l1_regularization = _Regularization(model, 0.1, p=1, logger=logger)
    l2_regularization = _Regularization(model, 1e-4, p=2, logger=logger)
    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])

    batch_acm = 0
    global_step = 0
    loss_siamese_acm = sim_acm = dis_sim_acm = loss_l1_acm = loss_l2_acm = 0.
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')

    optimizer.zero_grad()
    while True:
        if patience <= 0:
            break
        for sketch_batch, image_batch, label_batch in dataloader_train:
            # True once per value of ``global_step`` (start of a window).
            at_step_boundary = batch_acm % args.cum_num == 0

            if global_step % args.print_every == 0 and global_step and at_step_boundary:
                logger.info(
                    'Iter {}, Loss/siamese {:.3f}, Loss/l1 {:.3f}, Loss/l2 {:.3f}, '
                    'Siamese/sim {:.3f}, Siamese/dis_sim {:.3f}'.format(
                        global_step,
                        _window_mean(loss_siamese_acm),
                        _window_mean(loss_l1_acm),
                        _window_mean(loss_l2_acm),
                        _window_mean(sim_acm),
                        _window_mean(dis_sim_acm)))
                loss_siamese_acm = sim_acm = dis_sim_acm = loss_l1_acm = loss_l2_acm = 0.

            if global_step % args.save_every == 0 and at_step_boundary and global_step:
                os.makedirs(args.save_dir, exist_ok=True)
                torch.save({'args': args, 'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           '{}/Iter_{}.pkl'.format(args.save_dir, global_step))

                # Evaluation: feature extraction needs no autograd graph.
                model.eval()
                with torch.no_grad():
                    image_label = list()
                    image_feature = list()
                    for image, label in data.load_test_images(batch_size=args.batch_size):
                        image = _to_device(image)
                        image_label += label
                        image_feature.append(model.get_feature(image).cpu().detach().numpy())
                    image_feature = np.vstack(image_feature)

                    sketch_label = list()
                    sketch_feature = list()
                    for sketch, label in data.load_test_sketch(batch_size=args.batch_size):
                        sketch = _to_device(sketch)
                        sketch_label += label
                        sketch_feature.append(model.get_feature(sketch).cpu().detach().numpy())
                    sketch_feature = np.vstack(sketch_feature)

                dists_cosine = cdist(image_feature, sketch_feature, 'cosine')
                print(dists_cosine.shape)
                dists_euclid = cdist(image_feature, sketch_feature, 'euclidean')
                rank_cosine = np.argsort(dists_cosine, 0)
                rank_euclid = np.argsort(dists_euclid, 0)
                for n in [5, 200]:
                    precision_cosine = _precision_at(rank_cosine, n, image_label, sketch_label)
                    precision_euclid = _precision_at(rank_euclid, n, image_label, sketch_label)
                    writer.add_scalar('Precision_{}/cosine'.format(n), precision_cosine, global_step)
                    writer.add_scalar('Precision_{}/euclid'.format(n), precision_euclid, global_step)
                    logger.info('Iter {}, Precision_{}/cosine {}'.format(
                        global_step, n, precision_cosine))
                    logger.info('Iter {}, Precision_{}/euclid {}'.format(
                        global_step, n, precision_euclid))

                # ``precision_cosine`` now holds cosine precision@200.
                if best_precision < precision_cosine:
                    patience = args.patience
                    best_precision = precision_cosine
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision, best_iter)
                    logger.info('Iter {}, Best Precision_200 {}'.format(
                        global_step, best_precision))
                    torch.save({'args': args, 'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
                    if patience <= 0:
                        break
                model.train()

            batch_acm += 1
            # Linear warm-up; skipped when warmup_steps is 0 (division by zero).
            if args.warmup_steps > 0 and global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr * global_step / args.warmup_steps)

            sketch = _to_device(sketch_batch)
            image = _to_device(image_batch)
            label = _to_device(label_batch.float())
            sketch_feature, image_feature = model(sketch, image)
            loss_siamese, sim, dis_sim = siamese_loss(
                sketch_feature, image_feature, label, args.margin,
                loss_type=args.loss_type, distance_type=args.distance_type)
            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()

            loss_siamese_acm += loss_siamese.item()
            sim_acm += sim.item()
            dis_sim_acm += dis_sim.item()
            loss_l1_acm += loss_l1.item()
            loss_l2_acm += loss_l2.item()
            writer.add_scalar('Loss/Siamese', loss_siamese.item(), global_step)
            writer.add_scalar('Loss/L1', loss_l1.item(), global_step)
            writer.add_scalar('Loss/L2', loss_l2.item(), global_step)
            writer.add_scalar('Siamese/Similar', sim.item(), global_step)
            writer.add_scalar('Siamese/Dis-Similar', dis_sim.item(), global_step)

            # L1 is logged only; the objective is the Siamese loss + L2.
            loss = loss_siamese + loss_l2
            loss.backward()
            if batch_acm % args.cum_num == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # Clear only after stepping so gradients accumulate across
                # the ``cum_num`` mini-batches of the window.
                optimizer.zero_grad()
                global_step += 1
def train(args):
    """Train ``CMDTrans_model`` and select the best checkpoint by precision.

    Each mini-batch ``(sketch, paired image, unpaired image, negative image)``
    is rectified with ReLU, cast to float and fed to the model.  The model
    returns a dict of losses (``kl``, ``orthogonality``, ``triplet``,
    ``image``, ``sketch``).  The objective is the ``loss_weight``-weighted sum
    of ``image``, ``sketch``, ``triplet`` and ``kl``.  The orthogonality term
    and the L1/L2 regularisers are logged only.  Gradients are accumulated
    over ``args.cum_num`` mini-batches, clipped to norm 1.0 and applied with
    Adam.

    Every ``args.save_every`` steps a checkpoint ``<save_dir>/Iter_<step>.pkl``
    is written and two cosine-distance matrices are built on the test split:
    structure features (``inference_structure``, "S") and generation
    features ("G": ``inference_generation`` for sketches, the rectified input
    for images).  ``cal_matrics`` turns them into per-lambda precision/mAP
    plus a "Compare" score.  The maximum of all precisions drives
    ``Best.pkl`` / patience-based early stopping.

    Args:
        args: parsed command-line namespace; ``args.gpu_id == -1`` runs on CPU.
    """
    relu = nn.ReLU(inplace=True)
    writer = SummaryWriter()
    logger = make_logger(args.log_file)
    use_gpu = args.gpu_id != -1

    def _prepare(tensor):
        # Rectify (in place), cast to float32 and move to the requested
        # device.  ReLU and the cast also run on CPU so that CPU and GPU runs
        # feed the model identical inputs.
        tensor = relu(tensor).float()
        return tensor.cuda(args.gpu_id) if use_gpu else tensor

    if args.zs:
        packed = args.packed_pkl_zs
    else:
        packed = args.packed_pkl_nozs
    logger.info('Loading the data ...')
    data = CMDTrans_data(args.sketch_dir, args.image_dir, args.stats_file,
                         args.embedding_file, packed, args.preprocess_data,
                         args.raw_data, zs=args.zs, sample_time=1, cvae=True,
                         paired=True, cut_part=False, ranking=True,
                         tu_berlin=args.tu_berlin, strong_pair=args.strong_pair)
    dataloader_train = DataLoader(dataset=data, num_workers=args.num_worker,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)
    logger.info('Training sketch size: {}'.format(len(data.path2class_sketch.keys())))
    logger.info('Training image size: {}'.format(len(data.path2class_image.keys())))
    logger.info('Testing sketch size: {}'.format(len(data.path2class_sketch_test.keys())))
    logger.info('Testing image size: {}'.format(len(data.path2class_image_test.keys())))

    logger.info('Building the model ...')
    model = CMDTrans_model(args.pca_size, args.raw_size, args.hidden_size,
                           args.semantics_size, data.pretrain_embedding.float(),
                           dropout_prob=args.dropout,
                           fix_embedding=args.fix_embedding,
                           seman_dist=args.seman_dist,
                           triplet_dist=args.triplet_dist,
                           margin1=args.margin1, margin2=args.margin2,
                           logger=logger)
    # Move the model before the optimizer is built/restored so that restored
    # optimizer state is created on the parameters' device.
    if use_gpu:
        model.cuda(args.gpu_id)

    logger.info('Building the optimizer ...')
    optimizer = Adam(params=model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    l1_regularization = _Regularization(model, args.l1_weight, p=1, logger=logger)
    l2_regularization = _Regularization(model, args.l2_weight, p=2, logger=logger)
    if args.start_from is not None:
        logger.info('Loading pretrained model from {} ...'.format(args.start_from))
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    optimizer.zero_grad()

    # Running sums since the last report (five model losses and two
    # regularisers), in logging order.
    running = {'Triplet': 0., 'Orthogonality': 0., 'KL': 0., 'Image': 0.,
               'Sketch': 0., 'L1': 0., 'L2': 0.}
    # Mini-batch counter and optimisation-step counter.
    batch_acm = 0
    global_step = 0
    # Best-result tracking.
    best_precision = 0.
    best_iter = 0
    patience = args.patience
    logger.info('Hyper-Parameter:')
    logger.info(args)
    logger.info('Model Structure:')
    logger.info(model)
    logger.info('Begin Training !')
    # The orthogonality weight is defined but that term is not optimised.
    loss_weight = {'kl': 1.0, 'triplet': 1.0, 'orthogonality': 0.01,
                   'image': 1.0, 'sketch': 10.0}

    while True:
        if patience <= 0:
            break
        for sketch_batch, image_pair_batch, image_unpair_batch, image_n_batch in dataloader_train:
            # True once per value of ``global_step`` (start of a window).
            at_step_boundary = batch_acm % args.cum_num == 0

            if global_step % args.print_every == 0 and global_step and at_step_boundary:
                logger.info('*** Iter {} ***'.format(global_step))
                for name, total in running.items():
                    logger.info(' Loss/{} {:.3}'.format(
                        name, total / args.print_every / args.cum_num))
                running = dict.fromkeys(running, 0.)

            if global_step % args.save_every == 0 and at_step_boundary and global_step:
                os.makedirs(args.save_dir, exist_ok=True)
                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '{}/Iter_{}.pkl'.format(args.save_dir, global_step))

                # Evaluation: feature extraction needs no autograd graph.
                model.eval()
                with torch.no_grad():
                    image_label = list()
                    image_feature1 = list()  # S: structure features
                    image_feature2 = list()  # G: the rectified image itself
                    for image, label in data.load_test_images(batch_size=args.batch_size):
                        image = _prepare(image)
                        image_label += label
                        image_feature1.append(model.inference_structure(image, 'image').detach())
                        image_feature2.append(image.detach())
                    image_feature1 = torch.cat(image_feature1)
                    image_feature2 = torch.cat(image_feature2)

                    sketch_label = list()
                    sketch_feature1 = list()  # S: structure features
                    sketch_feature2 = list()  # G: generated features
                    for sketch, label in data.load_test_sketch(batch_size=args.batch_size):
                        sketch = _prepare(sketch)
                        sketch_label += label
                        sketch_feature1.append(model.inference_structure(sketch, 'sketch').detach())
                        sketch_feature2.append(model.inference_generation(sketch).detach())
                    sketch_feature1 = torch.cat(sketch_feature1)
                    sketch_feature2 = torch.cat(sketch_feature2)

                    dists_cosine1 = cosine_distance(
                        image_feature1, sketch_feature1).cpu().detach().numpy()
                    dists_cosine2 = cosine_distance(
                        image_feature2, sketch_feature2).cpu().detach().numpy()

                Precision_list, mAP_list, lambda_list, Precision_c, mAP_c = \
                    cal_matrics(dists_cosine1, dists_cosine2, image_label, sketch_label)
                logger.info('*** Evaluation Iter {} ***'.format(global_step))
                for idx, item in enumerate(lambda_list):
                    writer.add_scalar('Precision_200/{}'.format(item),
                                      Precision_list[idx], global_step)
                    writer.add_scalar('mAP_200/{}'.format(item),
                                      mAP_list[idx], global_step)
                    logger.info(' Precision/{} {:.3}'.format(item, Precision_list[idx]))
                    logger.info(' mAP/{} {:.3}'.format(item, mAP_list[idx]))
                writer.add_scalar('Precision_200/Compare', Precision_c, global_step)
                writer.add_scalar('mAP_200/Compare', mAP_c, global_step)
                logger.info(' Precision/Compare {:.3}'.format(Precision_c))
                logger.info(' mAP/Compare {:.3}'.format(mAP_c))

                # Model selection uses the best of all reported precisions.
                Precision_list.append(Precision_c)
                Precision = max(Precision_list)
                if best_precision < Precision:
                    patience = args.patience
                    best_precision = Precision
                    best_iter = global_step
                    writer.add_scalar('Best/Precision_200', best_precision, best_iter)
                    logger.info('=== Iter {}, Best Precision_200 {:.3} ==='.format(
                        global_step, best_precision))
                    torch.save({'args': args, 'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               '{}/Best.pkl'.format(args.save_dir))
                else:
                    patience -= 1
                    if patience <= 0:
                        break
                model.train()

            batch_acm += 1
            # Linear warm-up; skipped when warmup_steps is 0 (division by zero).
            if args.warmup_steps > 0 and global_step <= args.warmup_steps:
                update_lr(optimizer, args.lr * global_step / args.warmup_steps)

            sketch_batch = _prepare(sketch_batch)
            image_pair_batch = _prepare(image_pair_batch)
            image_unpair_batch = _prepare(image_unpair_batch)
            image_n_batch = _prepare(image_n_batch)
            loss = model(sketch_batch, image_pair_batch, image_unpair_batch, image_n_batch)
            loss_l1 = l1_regularization()
            loss_l2 = l2_regularization()

            loss_kl = loss['kl'].item()
            loss_orth = loss['orthogonality'].item()
            loss_triplet = loss['triplet'].item()
            loss_img = loss['image'].item()
            loss_ske = loss['sketch'].item()
            # Undo the regulariser weights so the raw norm terms are logged.
            reg_l1 = loss_l1.item() / args.l1_weight
            reg_l2 = loss_l2.item() / args.l2_weight
            running['Triplet'] += loss_triplet
            running['Orthogonality'] += loss_orth
            running['KL'] += loss_kl
            running['Image'] += loss_img
            running['Sketch'] += loss_ske
            running['L1'] += reg_l1
            running['L2'] += reg_l2
            writer.add_scalar('Loss/KL', loss_kl, global_step)
            writer.add_scalar('Loss/Orthogonality', loss_orth, global_step)
            writer.add_scalar('Loss/Triplet', loss_triplet, global_step)
            writer.add_scalar('Loss/Image', loss_img, global_step)
            writer.add_scalar('Loss/Sketch', loss_ske, global_step)
            writer.add_scalar('Loss/Reg_l1', reg_l1, global_step)
            writer.add_scalar('Loss/Reg_l2', reg_l2, global_step)

            # Orthogonality and the L1/L2 regularisers are excluded from the
            # objective; they are tracked above for monitoring only.
            loss_ = (loss['image'] * loss_weight['image']
                     + loss['sketch'] * loss_weight['sketch']
                     + loss['triplet'] * loss_weight['triplet']
                     + loss['kl'] * loss_weight['kl'])
            loss_.backward()
            if batch_acm % args.cum_num == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                global_step += 1
                optimizer.zero_grad()