def contrastive_loss(self, input, feat, target, conf="none", thresh=0.1, distmetric="l2"):
    """Pairwise contrastive loss over confidence-filtered samples.

    Samples may be filtered by prediction confidence, then each kept
    sample is paired with a random other kept sample; pairs whose argmax
    classes agree are labelled "similar" (1) for the contrastive criterion.

    Args:
        input: logits tensor, shape (N, C).
        feat: feature tensor aligned with ``input``, shape (N, D).
        target: raw scores used to estimate confidence (softmaxed here).
        conf: confidence filter — "max", "entropy", or "none" (keep all).
        thresh: confidence threshold a sample must exceed to be kept.
        distmetric: distance metric forwarded to ContrastiveLoss.

    Returns:
        Scalar loss tensor, or 0 when no sample passes the filter.
    """
    softmax = nn.Softmax(dim=1)
    target = softmax(target.view(-1, target.shape[-1])).view(target.shape)
    if conf == "max":
        # Confidence = highest softmax probability per sample.
        weight = torch.max(target, dim=1).values
        w = torch.tensor(
            [i for i, x in enumerate(weight) if x > thresh], dtype=torch.long
        ).to(self.device)
    elif conf == "entropy":
        # Confidence = 1 - normalized entropy of the softmax distribution.
        weight = torch.sum(-torch.log(target + 1e-6) * target, dim=1)
        # NOTE(review): normalizes by log(batch_size); log(num_classes)
        # looks like the intended denominator — confirm before changing.
        weight = 1 - weight / np.log(weight.size(-1))
        w = torch.tensor(
            [i for i, x in enumerate(weight) if x > thresh], dtype=torch.long
        ).to(self.device)
    else:
        # Fix: conf == "none" (the default) previously left `w` undefined,
        # so `input[w]` raised NameError. Keep every sample in that case.
        w = torch.arange(input.shape[0], dtype=torch.long).to(self.device)
    input_x = input[w]
    feat_x = feat[w]
    batch_size = input_x.size()[0]
    if batch_size == 0:
        return 0
    # Pair each kept sample with a random other kept sample.
    index = torch.randperm(batch_size).to(self.device)
    input_y = input_x[index, :]
    feat_y = feat_x[index, :]
    argmax_x = torch.argmax(input_x, dim=1)
    argmax_y = torch.argmax(input_y, dim=1)
    # 1.0 where the paired predictions agree, else 0.0.
    agreement = (argmax_x == argmax_y).float().to(self.device)
    criterion = ContrastiveLoss(margin=1.0, metric=distmetric)
    loss, dist_sq, dist = criterion(feat_x, feat_y, agreement)
    return loss
def train(train_loader, model, optimizer, epoch):
    """Run one training epoch with the contrastive criterion.

    Args:
        train_loader: yields (anchor, positive, label) batches.
        model: siamese embedding network, called once per branch.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (used for logging/checkpoint naming).

    Side effects: plots the per-batch loss, prints progress every
    ``config.log_interval`` global steps, and saves a checkpoint for
    this epoch under ``config.log_dir``.
    """
    model.train()
    loss_function = ContrastiveLoss()
    for batch_idx, (data_a, data_p, c) in enumerate(train_loader):
        data_a, data_p, c = data_a.cuda(), data_p.cuda(), c.cuda()
        data_a, data_p, c = Variable(data_a), Variable(data_p), Variable(c)
        out_a, out_p = model(data_a), model(data_p)
        loss = loss_function(out_a, out_p, c)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update the optimizer learning rate
        adjust_learning_rate(optimizer)
        # Fix: `loss.data[0]` fails on 0-dim loss tensors (PyTorch >= 0.4);
        # use `.item()` as the other train/eval loops in this file do.
        plotter.plot('loss', 'train', epoch * config.n_batch + batch_idx,
                     loss.item())
        if (epoch * config.n_batch + batch_idx) % config.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data_a), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    # Persist the model once per epoch.
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(config.log_dir, epoch))
def on_recv_do_train(message):
    """Handle a 'do_train' message: load the patch model, build a class
    map over the whole-slide image, and send patch info back to the client.

    Args:
        message: dict-like payload; ``message['alpha']`` is the blending
            weight forwarded to the visualization step.
    """
    logging.debug('on_recv_do_train')
    alpha = float(message['alpha'])
    # get patch image path
    image_dir = da.get_image_folder()
    # Expose both GPUs before any CUDA initialization happens.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): criterion / contrasive / transform are constructed but
    # never used in this handler — presumably leftovers from an inline
    # training path; confirm before removing.
    criterion = nn.CrossEntropyLoss()
    contrasive = ContrastiveLoss(margin=2.0)  # .to(device)
    crop_size = 3*256
    transform = transforms.Compose([
        transforms.CenterCrop((crop_size, crop_size)),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))])
    # 3-class DenseNet variant; weights restored from a fixed checkpoint.
    model = SimDenseNet2(growthRate=32, depth=24, reduction=0.5,
                         bottleneck=True, nClasses=3).to(device)
    model = nn.DataParallel(model)
    model_dir = os.path.join('', 'static/model/')
    model_path = model_dir + '21_model22.pth'
    model.load_state_dict(torch.load(model_path))
    # model = model.module
    test_dir = os.path.join('', 'static/images/pre/images/')
    trainer = Trainer(model=model, minmax_epochs=(10, 30), alpha=alpha,
                      batch_size=16, test_dir=test_dir)
    # Train with hp1 200 patches
    # trainer.data_dir = image_dir
    # trainer.num_samples = 5
    # trainer.train()
    whole_width = da.get_image_size()[0]
    whole_height = da.get_image_size()[1]
    # drop out /images.
    dir_full_patches = da.get_image_folder() + 'whole_patches/'
    # Build the per-patch class map for the whole slide.
    cmap = trainer.viz_WSI_ft(whole_path=dir_full_patches,
                              whole_wh=(whole_width, whole_height),
                              alpha=alpha, dis_th=0.7)  # Use only softmax
    logging.debug(cmap)
    logging.debug('train() complete')
    # set patch information
    da.set_patches(cmap)
    image_size = da.get_image_size()
    patches_info = da.get_patches()
    info = {
        "image_size": image_size,
        "patches_info": patches_info
    }
    info_json = json.dumps(info)
    send_msg('patches_info', info_json)
def __init__(self, model, minmax_epochs, alpha, batch_size, test_dir, f_lambda=1.0):
    """Set up training state for the patch classifier.

    Args:
        model: network to fine-tune; moved to the detected device and
            wrapped in DataParallel.
        minmax_epochs: (min_epochs, max_epochs) tuple bounding training.
        alpha: blending weight used by the training routine.
        batch_size: mini-batch size.
        test_dir: directory holding evaluation images.
        f_lambda: weight of the feature (contrastive) loss term.
    """
    # MISC
    self.device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu')
    self.prev_val_f_loss = np.inf
    self.curr_val_f_loss = 0
    self.test_dir = test_dir
    self.alpha = alpha
    # Archive accumulating per-class feature sums/counts so a running
    # mean (sum / count) can be derived per tissue class.
    # NOTE(review): 224 assumed to be the feature dimension — confirm.
    self.archive = {
        cls: {'sum': np.zeros(224), 'count': 0}
        for cls in ('hp', 'nor', 'ta')
    }
    # Model & Optimizer — the optimizer is built on the module's own
    # parameters before the DataParallel wrap; both reference the same
    # underlying tensors, so this ordering is safe.
    self.model = model.to(self.device)
    self.optimizer = optim.Adam(self.model.parameters(), betas=(0.5, 0.99))
    self.model = nn.DataParallel(model)
    self.contrasive = ContrastiveLoss(margin=2.0)  # .to(device)
    self.criterion = nn.CrossEntropyLoss()
    # Training configuration
    self.transform = transforms.Compose([
        transforms.CenterCrop((768, 768)),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    self._num_samples = None
    self._data_dir = None
    # Fix: batch_size was assigned twice in the original; keep one.
    self.batch_size = batch_size
    self.min_epochs = minmax_epochs[0]
    self.max_epochs = minmax_epochs[1]
    self.f_lambda = f_lambda
def eval_fn(data_loader, model, device, test=False):
    """Evaluate the siamese model on ``data_loader``.

    Runs the model without gradients, derives a binary prediction from
    the embedding distance, and reports accuracy / F1 / AUC over the
    whole set.

    Args:
        data_loader: evaluation batches carrying a "label" field.
        model: siamese encoder (invoked through ``run_one_step``).
        device: device the labels are moved to.
        test: unused here; kept for interface compatibility.

    Returns:
        Tuple ``(acc, f1, auc)`` computed over all batches.
    """
    # Eval mode: turn off dropout; batch-norm uses running statistics.
    # Reference: https://github.com/pytorch/pytorch/issues/5406
    model.eval()
    true_labels = []
    pred_labels = []
    # Hoisted out of the batch loop — the criterion is stateless, so
    # rebuilding it every batch was pure overhead.
    loss_fn = ContrastiveLoss(margin=1)
    # Gradient tracking disabled for the whole evaluation pass.
    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader))
        for bi, batch in enumerate(tk0):
            query_sequence, question_sequence, query_pooled, question_pooled = run_one_step(
                batch, model, device)
            labels = batch["label"].to(device)
            # Distance between the two pooled embeddings (non-cosine).
            distance = cal_distance(query_pooled, question_pooled, cos=False)
            loss = loss_fn(distance, labels)
            # Distance above threshold => predict class 1.
            # NOTE(review): the config attribute really is spelled
            # THRESHOLDE upstream — do not "fix" only this side.
            pred_label = [1 if d > config.THRESHOLDE else 0 for d in distance]
            labels = labels.cpu().numpy().tolist()
            pred_labels.extend(pred_label)
            true_labels.extend(labels)
            acc, f1 = calculate_metrics_score(
                label=labels,
                pred_label=pred_label,
            )
            # Show running loss / acc / f1 on the progress bar.
            tk0.set_postfix(loss=loss.item(), acc=acc, f1=f1)
    acc, f1, auc = calculate_metrics_score(label=true_labels,
                                           pred_label=pred_labels,
                                           cal_auc=True)
    logger.info(f"acc = {acc}, f1 = {f1}, auc={auc}")
    return acc, f1, auc
def train_fn(data_loader, model, optimizer, device, scheduler=None, threshold=None):
    """Train the siamese model for one epoch.

    Args:
        data_loader: training batches carrying a "label" field.
        model: siamese encoder (invoked through ``run_one_step``).
        optimizer: stepped once per batch.
        device: device the labels are moved to.
        scheduler: optional LR scheduler, stepped per batch when given.
        threshold: distance cut-off used to binarize predictions.
    """
    # Training mode: dropout active, batch-norm uses batch statistics.
    model.train()
    tk0 = tqdm(data_loader, total=len(data_loader))
    # Criterion is stateless — build it once rather than per batch.
    loss_fn = ContrastiveLoss(margin=1)
    for bi, batch in enumerate(tk0):
        query_sequence, question_sequence, query_pooled, question_pooled = run_one_step(
            batch, model, device)
        labels = batch["label"].to(device)
        distance = cal_distance(query_pooled, question_pooled, cos=False)
        loss = loss_fn(distance, labels)
        # Fix: gradients were never cleared, so they accumulated across
        # batches; reset them before each backward pass.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Fix: scheduler defaults to None — only step when supplied.
        if scheduler is not None:
            scheduler.step()
        pred_labels = [1 if d > threshold else 0 for d in distance]
        # Per-batch metrics for the progress bar.
        acc, f1 = calculate_metrics_score(
            label=labels.cpu().numpy(),
            pred_label=np.array(pred_labels),
        )
        tk0.set_postfix(loss=loss.item(), acc=acc, f1=f1)
def train(self,):
    """Train the catchphrase encoder with a contrastive objective.

    Per case: encode one case sentence and one randomly chosen
    catchphrase, accumulate the pooled embeddings, and run a training
    step every ``batch_size`` iterations. The loss history is plotted,
    an evaluation pass runs each epoch, and a timestamped checkpoint is
    saved per epoch.
    """
    LOGGER.info('\n---------------- Train Starting ----------------')
    # Load training/validation data
    LOGGER.info('Load training/validation data')
    LOGGER.info('------------------------------')
    train_dataset = self.load_dataset("train")
    val_dataset = self.load_dataset("val")  # NOTE(review): loaded but unused here
    criterion = ContrastiveLoss()
    optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
    n_epcoh = self.exp_config.epochs
    print_every = 5000  # NOTE(review): currently unused
    plot_every_n_batch = self.exp_config.plot_every_n_batch  # NOTE(review): unused
    batch_size = self.exp_config.batch_size
    import random
    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []
    batch_x = []
    batch_i = 1
    # Embedding accumulators, flushed every `batch_size` iterations.
    sentence_tensor = []
    catch_tensor = []
    self.model.train()
    for epoch in range(n_epcoh):
        LOGGER.info("epoch {} starts".format(epoch))
        for i, case in enumerate(tqdm(train_dataset["all_cases"][:self.exp_config.iter_per_epoch])):
            # Encode one sentence of the case (pooled BERT output).
            text_idx = self.model.tokenizer(train_dataset["case_sentences"][case],
                                            truncation=True, return_tensors="pt",
                                            padding='max_length', max_length=512).to(self.device)
            last_hidden_state, pooler_output = self.model.encoder(**text_idx)
            sentence_tensor.append(pooler_output)
            # Pick one catchphrase of this case at random and encode it.
            catchphrase_id = randomChoice(train_dataset["case_catchphrases"][case])
            catchphrase = train_dataset["idx_catchphrases"][catchphrase_id]
            text_idx = self.model.tokenizer(catchphrase, truncation=True,
                                            return_tensors="pt", padding='max_length',
                                            max_length=18).to(self.device)
            last_hidden_state, pooler_output = self.model.encoder(**text_idx)
            catch_tensor.append(pooler_output)
            # Flush the accumulated pairs as one training batch.
            # NOTE(review): fires at i == 0 too, so the first "batch"
            # contains a single pair — confirm this is intended.
            if i % batch_size == 0:
                sentence_tensor = torch.cat(sentence_tensor, dim=0).to(self.device)
                catch_tensor = torch.cat(catch_tensor, dim=0).to(self.device)
                batch_loss = self.train_step(sentence_tensor, catch_tensor, criterion, optimizer)
                # Log a slice of the first layer's weights to verify
                # parameters are actually updating.
                LOGGER.info("Record model.nn1.weight.data[0][:10]:")
                LOGGER.info(self.model.nn1.weight.data[0][:10])
                batch_x.append(batch_i)
                batch_i += 1
                LOGGER.info("loss = " + str(batch_loss/batch_size))
                # current_loss += batch_loss
                # Reset accumulators for the next batch.
                sentence_tensor = []
                catch_tensor = []
                # # Add current loss avg to list of losses
                # if i % plot_every_n_batch * batch_size == 0:
                all_losses.append(batch_loss/batch_size)
                # current_loss = 0
        # Per-epoch bookkeeping: plot losses, evaluate, checkpoint.
        self.plot_loss(batch_x, all_losses)
        self.evaluate()
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': all_losses[-1],
        }, os.path.join(self.exp_config.checkpoint_path,
                        "model{}.pt".format(strftime("%Y_%m_%d_%H_%M_%S", gmtime()))))
        LOGGER.info("checkpoint saved")
def main(netname, nepoch, lossname, opt, lr, dircheckpoint, dirdataset, device, envplotter):
    """Train a sketch-matching network with a selectable loss.

    Args:
        netname: backbone key — "alexnet", "resnet", or "vgg".
        nepoch: number of training epochs (stored into Config).
        lossname: "triplet", "contrast", or "crossEntropyLoss".
        opt: optimizer name; anything but 'adm' selects SGD+momentum.
        lr: learning rate.
        dircheckpoint: directory for checkpoints (created if missing;
            the newest matching checkpoint is resumed from).
        dirdataset: dataset directory (logged only in this function).
        device: torch device for model and tensors.
        envplotter: Visdom environment name for live plots.
    """
    # Echo the run configuration.
    print(20 * "-")
    print("netname:", netname)
    print('n_epoch:', nepoch)
    print("loss:", lossname)
    print("opt:", opt)
    print("lr:", lr)
    print("dircheckpoint:", dircheckpoint)
    print("dirDataset:", dirdataset)
    print("device:", device)
    print('envplotter:', envplotter)
    print(20 * "-")
    Config.train_number_epochs = nepoch
    # Backbone registry keyed by CLI name.
    networks = {
        "alexnet": SketchNetwork,
        "resnet": SketchNetworkResnet,
        "vgg": SketchNetworkVGG
    }
    net = networks[netname]()
    #net = nn.DataParallel(net)
    net = net.to(device)
    # Inception expects 299x299 inputs; everything else uses 224.
    # NOTE(review): "inception" is not in the `networks` dict, so this
    # branch is currently unreachable — confirm.
    image_size = 224
    if netname == "inception":
        image_size = 299
    print(image_size)
    # All three criteria are built; `losses` adapts each to a common
    # (anchor, positive, negative) calling convention.
    criterion_triplet = nn.TripletMarginLoss(margin=1.0)
    criterion_contrast = ContrastiveLoss()
    criterion_CE = nn.CrossEntropyLoss()
    losses = {
        "triplet": lambda x, y, z: criterion_triplet(x, y, z),
        # Contrastive: positive pair labelled +1, negative pair -1.
        "contrast": lambda x, y, z: (criterion_contrast(x, y, torch.ones(Config.train_batch_size).to(device)) + criterion_contrast(
            x, z, -1 * torch.ones(Config.train_batch_size).to(device))),
        # Binary CE over concatenated pair logits (0 = match, 1 = no match).
        "crossEntropyLoss": lambda out0, out1: (criterion_CE(
            torch.cat((out0, out1)), torch.cat(
                (torch.zeros(Config.train_batch_size),
                 torch.ones(Config.train_batch_size))).to(device).long()))
    }
    criterion = losses[lossname]
    uses_triple = False
    if lossname == "triplet":
        uses_triple = True
    if opt != 'adm':
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    else:
        optimizer = optim.Adam(net.parameters(), lr=lr)
    plotter = VisdomLinePlotter(env_name=envplotter, port=8097)
    epoch_loss = 0
    valid_epoch_loss = 0
    best_loss = -1  # -1 acts as "no best yet" sentinel
    num_batch = 1
    num_batch_val = 1
    # Resume from the newest checkpoint matching this net/loss combo.
    sh.mkdir("-p", dircheckpoint)
    files_checkpoints = np.array(
        sorted(glob.glob(dircheckpoint + "/*{}_{}*".format(netname, lossname))))
    if (files_checkpoints.shape[0]):
        checkpoint = torch.load(files_checkpoints[-1])
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(files_checkpoints[-1])
    # Shared augmentation pipeline for train and validation.
    transf = transforms.Compose([
        transforms.RandomRotation((-45, 45), fill=(255, 255, 255, 1)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor()
    ])
    #DATA LOADERS
    sketch_dataset = SketchZoomDataset(data_sketch_root="data/",
                                       net=net,
                                       plotter=None,
                                       n=Config.train_data_n,
                                       stage="train",
                                       image_size=image_size,
                                       triplet=uses_triple,
                                       categories=[
                                           "Airplane", "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar",
                                           "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol",
                                           "Rocket", "Skateboard", "Table"
                                       ],
                                       transform=transf,
                                       device=device)
    train_dataloader = DataLoader(sketch_dataset,
                                  shuffle=True,
                                  num_workers=0,
                                  batch_size=Config.train_batch_size,
                                  drop_last=True)
    sketch_dataset_test = SketchZoomDataset(
        data_sketch_root="data/",
        net=net,
        plotter=None,  #plotter,
        n=Config.val_data_n,
        stage="test",
        image_size=image_size,
        triplet=uses_triple,
        categories=[
            "Airplane", "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar",
            "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol",
            "Rocket", "Skateboard", "Table"
        ],
        transform=transf,
        device=device)
    validation_dataloader = DataLoader(sketch_dataset_test,
                                       shuffle=True,
                                       num_workers=0,
                                       batch_size=Config.val_batch_size,
                                       drop_last=True)
    # NOTE(review): range starts at 1, so nepoch-1 epochs actually run.
    for epoch in range(1, Config.train_number_epochs):
        print('EPOCH', epoch)
        valid_epoch_loss = 0
        epoch_loss = 0
        net.train()
        #TRAIN
        for i, data in enumerate(train_dataloader):
            # Dataset keeps a reference to the live net (used for mining).
            sketch_dataset.net = net
            net.train()
            img0, img1, img2 = data
            img0, img1, img2 = Variable(img0).to(device), Variable(img1).to(
                device), Variable(img2).to(device)
            optimizer.zero_grad()
            if lossname != 'crossEntropyLoss':
                # Triplet/contrastive path: one forward over all three images.
                output1, output2, output3 = net(img0, img1, img2,
                                                img1.size()[0])
                loss = criterion(output1, output2, output3)
            else:
                # Binary-classification path: two pairwise forwards.
                output1, output2, res0 = net.forward_two_binary(
                    img0, img1, img1.size()[0])
                output1, output3, res1 = net.forward_two_binary(
                    img1, img2, img2.size()[0])
                loss = criterion(res0, res1)
            # Embedding distances, for monitoring only (not in the loss).
            distances_negativa = F.pairwise_distance(output1, output3)
            distances_negativa = distances_negativa.data.cpu().numpy().flatten(
            )
            distances_positiva = F.pairwise_distance(output1, output2)
            distances_positiva = distances_positiva.data.cpu().numpy().flatten(
            )
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            print('*' * 20)
            print(" nro Batch {} -- Current loss {}\n".format(i, loss.item()))
            num_batch = num_batch + 1
            plotter.plot('Distance mean', str(0), num_batch,
                         np.mean(distances_positiva), "Batchs")
            plotter.plot('Distance mean', str(1), num_batch,
                         np.mean(distances_negativa), "Batchs")
            plotter.plot('Batchs loss', str(epoch), i + 1, loss.item(),
                         "Batchs")
            # Rolling checkpoint every 10 batches (overwritten each time).
            if i != 0 and i % 9 == 0:
                torch.save(
                    {
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    },
                    "{}/{}_{}_{}.pkl".format(dircheckpoint,
                                             "current_batch_checkpoints",
                                             netname, lossname))
                print("save net, checkpoint Batch")
        # Free per-batch tensors before validation.
        del i, data, distances_negativa, distances_positiva, output1, output2, output3, img0, img1, img2
        net.eval()
        with torch.no_grad():
            sketch_dataset.net = net
            #VALIDATION
            for i, data in enumerate(validation_dataloader):
                img0, img1, img2 = data
                img0, img1, img2 = Variable(img0).to(device), Variable(
                    img1).to(device), Variable(img2).to(device)
                if lossname != 'crossEntropyLoss':
                    output1, output2, output3 = net(img0, img1, img2,
                                                    img1.size()[0])
                    loss = criterion(output1, output2, output3)
                else:
                    output1, output2, res0 = net.forward_two_binary(
                        img0, img1, img1.size()[0])
                    output1, output3, res1 = net.forward_two_binary(
                        img0, img2, img2.size()[0])
                    loss = criterion(res0, res1)
                distances_negativa = F.pairwise_distance(output1, output3)
                distances_negativa = distances_negativa.data.cpu().numpy(
                ).flatten()
                distances_positiva = F.pairwise_distance(output1, output2)
                distances_positiva = distances_positiva.data.cpu().numpy(
                ).flatten()
                num_batch_val = num_batch_val + 1
                # NOTE(review): plots only element [0] here, unlike the
                # training loop which plots the mean over the batch.
                plotter.plot('Distance mean Valid', str(0), num_batch_val,
                             np.mean(distances_positiva[0]), "Batchs")
                plotter.plot('Distance mean Valid', str(1), num_batch_val,
                             np.mean(distances_negativa[0]), "Batchs")
                print(" nro Valid Batch{} -- Valid loss {}\n".format(
                    i + 1, loss.item()))
                valid_epoch_loss += loss.item()
            del i, data, distances_negativa, distances_positiva, output1, output2, output3, img0, img1, img2
        #END TRAIN
        # Average losses over the number of full batches per epoch.
        current_epoch_loss = epoch_loss / (Config.train_data_n //
                                           Config.train_batch_size)
        current_epoch_loss_val = valid_epoch_loss / (Config.val_data_n //
                                                     Config.val_batch_size)
        print("Epoch number {}\n Current loss average {}\n".format(
            epoch, current_epoch_loss))
        print("Epoch number {}\n Current loss val average {}\n".format(
            epoch, current_epoch_loss_val))
        plotter.plot('Epochs loss ', 'train epoch', epoch, current_epoch_loss,
                     "Epochs")
        plotter.plot('Epochs loss ', 'valid epoch', epoch,
                     current_epoch_loss_val, "Epochs")
        #SAVE NET WITH BEST LOSS
        # Best-model selection uses TRAIN loss, not validation loss.
        if (best_loss == -1 or current_epoch_loss < best_loss):
            torch.save(
                {
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, "{}/{}_{}_{}.pkl".format(dircheckpoint, "checkpoints",
                                            netname, lossname))
            best_loss = current_epoch_loss
            print("save net, loss: {}".format(best_loss))
batch_size=args.test_batch) nfeat = train_loader.__iter__().__next__()['input_anchor']['x'].shape[1] print("NFEAT: ",nfeat) print("Model: ",args.model) print("Scheduler: On") if args.no_scheduler else print("Scheduler: Off") if not args.no_windowed and args.input_type=='RST': print("Window: On") elif args.model == 'gcn_cheby': model = Siamese_GeoChebyConv(nfeat=nfeat, nhid=args.hidden, nclass=1, dropout=args.dropout) criterion = ContrastiveLoss(args.loss_margin) elif args.model == 'gcn_cheby_bce': model = Siamese_GeoChebyConv_Read(nfeat=nfeat, nhid=args.hidden, nclass=1, dropout=args.dropout) criterion = BCEWithLogitsLoss() elif args.model == 'gcn_cheby_cos': model = Siamese_GeoCheby_Cos(nfeat=nfeat, nhid=args.hidden, nclass=1, dropout=args.dropout) criterion = ContrastiveCosineLoss(args.temperature).to(device)
y_loss = {} # loss history y_loss['train'] = [] y_loss['val'] = [] y_err = {} y_err['train'] = [] y_err['val'] = [] def l2_norm(v): fnorm = torch.norm(v, p=2, dim=1, keepdim=True) + 1e-6 v = v.div(fnorm.expand_as(v)) return v xhloss = ContrastiveLoss() def compute_loss(model, input_ids, attention_mask, crop, motion, nl_id, crop_id, label, warm): if opt.motion: visual_embeds, lang_embeds, predict_class_v, predict_class_l, predict_class_motion = model.forward( input_ids, attention_mask, crop, motion.cuda()) else: visual_embeds, lang_embeds, predict_class_v, predict_class_l = model.forward( input_ids, attention_mask, crop) #print(similarity.shape, predict_class_v.shape, predict_class_l.shape) #print(label.shape, nl_id.shape) #label = label.float() visual_embeds = l2_norm(visual_embeds)