Code example #1
    def contrastive_loss(
        self, input, feat, target, conf="none", thresh=0.1, distmetric="l2"
    ):
        softmax = nn.Softmax(dim=1)
        target = softmax(target.view(-1, target.shape[-1])).view(target.shape)
        if conf == "max":
            weight = torch.max(target, axis=1).values
            w = torch.tensor(
                [i for i, x in enumerate(weight) if x > thresh], dtype=torch.long
            ).to(self.device)
        elif conf == "entropy":
            weight = torch.sum(-torch.log(target + 1e-6) * target, dim=1)
            weight = 1 - weight / np.log(weight.size(-1))
            w = torch.tensor(
                [i for i, x in enumerate(weight) if x > thresh], dtype=torch.long
            ).to(self.device)
        input_x = input[w]

        feat_x = feat[w]
        batch_size = input_x.size()[0]
        if batch_size == 0:
            return 0
        index = torch.randperm(batch_size).to(self.device)
        input_y = input_x[index, :]
        feat_y = feat_x[index, :]
        argmax_x = torch.argmax(input_x, dim=1)
        argmax_y = torch.argmax(input_y, dim=1)
        # 1.0 where the shuffled pair predicts the same class, 0.0 otherwise
        agreement = (argmax_x == argmax_y).float().to(self.device)

        criterion = ContrastiveLoss(margin=1.0, metric=distmetric)
        loss, dist_sq, dist = criterion(feat_x, feat_y, agreement)

        return loss
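
For reference, code example #1 calls ContrastiveLoss(margin=1.0, metric=distmetric), but the implementation itself is not shown. Below is a minimal sketch that is compatible with that call (criterion(feat_x, feat_y, agreement) returning loss, dist_sq and dist), following the standard margin-based contrastive loss of Hadsell et al. (2006). It is only an assumption, not the project's actual code; in particular the "cos" branch is assumed.

# Minimal sketch of a ContrastiveLoss module compatible with code example #1.
# Not the project's actual implementation; the "cos" metric branch is assumed.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0, metric="l2"):
        super().__init__()
        self.margin = margin
        self.metric = metric

    def forward(self, x0, x1, y):
        # y is 1 for pairs that should be close, 0 for pairs that should be far apart
        if self.metric == "cos":
            dist = 1.0 - F.cosine_similarity(x0, x1)
        else:  # default: Euclidean ("l2") distance
            dist = F.pairwise_distance(x0, x1)
        dist_sq = dist.pow(2)
        # Pull similar pairs together; push dissimilar pairs apart up to the margin
        loss = y * dist_sq + (1.0 - y) * torch.clamp(self.margin - dist, min=0).pow(2)
        return loss.mean(), dist_sq, dist
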
Code example #2
def train(train_loader, model, optimizer, epoch):
    model.train()
    # loss_function = DlibLoss()
    loss_function = ContrastiveLoss()
    # pbar = tqdm(enumerate(train_loader))

    for batch_idx, (data_a, data_p, c) in enumerate(train_loader):
        data_a, data_p, c = data_a.cuda(), data_p.cuda(), c.cuda()
        data_a, data_p, c = Variable(data_a), Variable(data_p), Variable(c)

        out_a, out_p = model(data_a), model(data_p)
        loss = loss_function(out_a, out_p, c)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update the optimizer learning rate
        adjust_learning_rate(optimizer)

        plotter.plot('loss', 'train', epoch * config.n_batch + batch_idx,
                     loss.item())

        if (epoch * config.n_batch + batch_idx) % config.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data_a), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(config.log_dir, epoch))
Code example #3
File: server.py  Project: kangyeolk/deep-labeling-vis
def on_recv_do_train(message):
	logging.debug('on_recv_do_train')

	alpha = float(message['alpha'])

	# get patch image path
	image_dir = da.get_image_folder()

	os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	criterion = nn.CrossEntropyLoss()
	contrasive = ContrastiveLoss(margin=2.0)# .to(device)

	crop_size = 3*256
	transform = transforms.Compose([
						transforms.CenterCrop((crop_size, crop_size)),
						transforms.Resize((256, 256)),
						transforms.ToTensor(),
						transforms.Normalize(mean=(0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))])

	model = SimDenseNet2(growthRate=32, depth=24, reduction=0.5, bottleneck=True, nClasses=3).to(device)
	model = nn.DataParallel(model)
	model_dir = os.path.join('', 'static/model/')
	model_path = model_dir + '21_model22.pth'
	model.load_state_dict(torch.load(model_path))

	# Unwrap the DataParallel wrapper before handing the model to the Trainer
	model = model.module


	test_dir = os.path.join('', 'static/images/pre/images/')
	trainer = Trainer(model=model, minmax_epochs=(10, 30), alpha=alpha, batch_size=16, test_dir=test_dir)

	# Train with hp1 200 patches
	# trainer.data_dir = image_dir
	# trainer.num_samples = 5
	# trainer.train()

	whole_width = da.get_image_size()[0]
	whole_height = da.get_image_size()[1]
	dir_full_patches = da.get_image_folder() + 'whole_patches/' # drop out /images.
	cmap = trainer.viz_WSI_ft(whole_path=dir_full_patches, whole_wh=(whole_width, whole_height), alpha=alpha, dis_th=0.7) # Use only softmax
	logging.debug(cmap)
	logging.debug('train() complete')

	# set patch information
	da.set_patches(cmap)

	image_size = da.get_image_size()
	patches_info = da.get_patches()
	info = {
		"image_size": image_size,
		"patches_info": patches_info
	}
	info_json = json.dumps(info)
	send_msg('patches_info', info_json)
Code example #4
    def __init__(self,
                 model,
                 minmax_epochs,
                 alpha,
                 batch_size,
                 test_dir,
                 f_lambda=1.0):

        # MISC
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.prev_val_f_loss = np.inf
        self.curr_val_f_loss = 0
        self.test_dir = test_dir
        self.alpha = alpha

        # Archive for saving mean of features
        self.archive = {}
        self.archive['hp'] = {}
        self.archive['nor'] = {}
        self.archive['ta'] = {}

        self.archive['hp']['sum'] = np.zeros(224)
        self.archive['hp']['count'] = 0
        # self.archive['hp']['avg'] = 0
        self.archive['nor']['sum'] = np.zeros(224)
        self.archive['nor']['count'] = 0
        # self.archive['nor']['avg'] = 0
        self.archive['ta']['sum'] = np.zeros(224)
        self.archive['ta']['count'] = 0
        # self.archive['ta']['avg'] = 0

        # Model & Optimizer
        self.model = model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), betas=(0.5, 0.99))
        self.model = nn.DataParallel(model)
        self.contrasive = ContrastiveLoss(margin=2.0)  # .to(device)
        self.criterion = nn.CrossEntropyLoss()

        # Training configuration
        self.transform = transforms.Compose([
            transforms.CenterCrop((768, 768)),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        self._num_samples = None
        self._data_dir = None
        self.batch_size = batch_size
        self.min_epochs = minmax_epochs[0]
        self.max_epochs = minmax_epochs[1]
        self.f_lambda = f_lambda
Code example #5
def eval_fn(data_loader, model, device, test=False):
    """
    Evaluation function to predict on the test set
    """
    # Set model to evaluation mode
    # I.e., turn off dropout and set batchnorm to use overall mean and variance (from training), rather than batch level mean and variance
    # Reference: https://github.com/pytorch/pytorch/issues/5406
    model.eval()
    true_labels = []
    pred_labels = []
    # Turns off gradient calculations (https://datascience.stackexchange.com/questions/32651/what-is-the-use-of-torch-no-grad-in-pytorch)
    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader))
        # Make predictions and calculate loss / acc, f1 score for each batch
        for bi, batch in enumerate(tk0):

            query_sequence, question_sequence, query_pooled, question_pooled = run_one_step(
                batch, model, device)

            labels = batch["label"].to(device)

            # Calculate loss for the batch
            distance = cal_distance(query_pooled, question_pooled, cos=False)
            # Calculate batch loss with the contrastive loss
            loss_fn = ContrastiveLoss(margin=1)
            loss = loss_fn(distance, labels)

            # Threshold the pairwise distance to obtain hard 0/1 predictions
            pred_label = [1 if d > config.THRESHOLDE else 0 for d in distance]
            labels = labels.cpu().numpy().tolist()
            pred_labels.extend(pred_label)
            true_labels.extend(labels)
            acc, f1 = calculate_metrics_score(
                label=labels,
                pred_label=pred_label,
            )
            # Print the running average loss and acc and f1 score
            tk0.set_postfix(loss=loss.item(), acc=acc, f1=f1)

    acc, f1, auc = calculate_metrics_score(label=true_labels,
                                           pred_label=pred_labels,
                                           cal_auc=True)
    logger.info(f"acc = {acc}, f1 = {f1}, auc={auc}")
    return acc, f1, auc
Code example #6
def train_fn(data_loader,
             model,
             optimizer,
             device,
             scheduler=None,
             threshold=None):
    """
    Trains the bert model on the twitter data
    """
    # Set model to training mode (dropout + sampled batch norm is activated)
    model.train()

    # Set tqdm to add loading screen and set the length
    tk0 = tqdm(data_loader, total=len(data_loader))
    # Train the model on each batch
    for bi, batch in enumerate(tk0):

        query_sequence, question_sequence, query_pooled, question_pooled = run_one_step(
            batch, model, device)
        labels = batch["label"].to(device)
        distance = cal_distance(query_pooled, question_pooled, cos=False)
        # Calculate batch loss with the contrastive loss
        loss_fn = ContrastiveLoss(margin=1)
        loss = loss_fn(distance, labels)
        # Calculate gradients based on loss
        loss.backward()
        # Adjust weights based on calculated gradients
        optimizer.step()
        # Update scheduler
        scheduler.step()

        pred_labels = [1 if d > threshold else 0 for d in distance]
        # Calculate accuracy and F1 score based on the predictions for this batch
        acc, f1 = calculate_metrics_score(
            label=labels.cpu().numpy(),
            pred_label=np.array(pred_labels),
        )
        # Print the running loss, accuracy and F1 score at the end of each batch
        tk0.set_postfix(loss=loss.item(), acc=acc, f1=f1)
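
Code examples #5 and #6 use a different calling convention: the pairwise distance is computed first with cal_distance and then passed to ContrastiveLoss(margin=1) together with the labels. The sketch below is only an assumption about that variant; the label convention (1 = dissimilar pair) is inferred from the thresholding rule "pred_label = 1 if distance > threshold" used in both examples.

# Minimal sketch of the distance-based ContrastiveLoss variant seen in
# code examples #5 and #6. Assumed, not the projects' actual implementation.
import torch
import torch.nn as nn


class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, distance, label):
        # label == 0: similar pair, pulled towards zero distance.
        # label == 1: dissimilar pair, pushed beyond the margin.
        label = label.float()
        loss = (1.0 - label) * distance.pow(2) + \
            label * torch.clamp(self.margin - distance, min=0).pow(2)
        return loss.mean()
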
Code example #7
    def train(self,):
        LOGGER.info('\n---------------- Train Starting ----------------')

        # Load training/validation data
        LOGGER.info('Load training/validation data')
        LOGGER.info('------------------------------')


        train_dataset=self.load_dataset("train")
        val_dataset=self.load_dataset("val")

        # LOGGER.info("start build catch repr for train")




        # n_hidden = 128
        # self.model = NN(768, n_hidden, 768).to(self.device)


        criterion = ContrastiveLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

        n_epoch = self.exp_config.epochs
        print_every = 5000
        plot_every_n_batch = self.exp_config.plot_every_n_batch

        batch_size = self.exp_config.batch_size
        import random


        # Keep track of losses for plotting
        current_loss = 0
        all_losses = []
        batch_x=[]
        batch_i=1

        sentence_tensor = []
        catch_tensor = []

        self.model.train()
        # w = self.model.nn1.weight.data.clone()
        for epoch in range(n_epoch):
            LOGGER.info("epoch {} starts".format(epoch))
            for i, case in enumerate(tqdm(train_dataset["all_cases"][:self.exp_config.iter_per_epoch])):
                text_idx = self.model.tokenizer(train_dataset["case_sentences"][case], truncation=True, return_tensors="pt",
                                     padding='max_length', max_length=512).to(self.device)
                last_hidden_state, pooler_output = self.model.encoder(**text_idx)
                sentence_tensor.append(pooler_output)

                catchphrase_id = randomChoice(train_dataset["case_catchphrases"][case])
                catchphrase = train_dataset["idx_catchphrases"][catchphrase_id]
                text_idx = self.model.tokenizer(catchphrase, truncation=True, return_tensors="pt",
                                     padding='max_length', max_length=18).to(self.device)
                last_hidden_state, pooler_output = self.model.encoder(**text_idx)
                catch_tensor.append(pooler_output)

                # Every batch_size iterations, run one training step on the accumulated pairs
                if i % batch_size == 0:
                    sentence_tensor = torch.cat(sentence_tensor, dim=0).to(self.device)
                    catch_tensor = torch.cat(catch_tensor, dim=0).to(self.device)

                    batch_loss = self.train_step(sentence_tensor, catch_tensor, criterion, optimizer)
                    LOGGER.info("Record model.nn1.weight.data[0][:10]:")
                    LOGGER.info(self.model.nn1.weight.data[0][:10])
                    batch_x.append(batch_i)
                    batch_i+=1
                    LOGGER.info("loss = "+str(batch_loss/batch_size))
                    # current_loss += batch_loss
                    sentence_tensor = []
                    catch_tensor = []

                # # Add current loss avg to list of losses
                # if i % plot_every_n_batch * batch_size == 0:
                    all_losses.append(batch_loss/batch_size)
                    # current_loss = 0
                    self.plot_loss(batch_x,all_losses)

            self.evaluate()

            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': all_losses[-1],
            }, os.path.join(self.exp_config.checkpoint_path, "model{}.pt".format(strftime("%Y_%m_%d_%H_%M_%S", gmtime()))))
            LOGGER.info("checkpoint saved")
Code example #8
def main(netname, nepoch, lossname, opt, lr, dircheckpoint, dirdataset, device,
         envplotter):
    print(20 * "-")
    print("netname:", netname)
    print('n_epoch:', nepoch)
    print("loss:", lossname)
    print("opt:", opt)
    print("lr:", lr)
    print("dircheckpoint:", dircheckpoint)
    print("dirDataset:", dirdataset)
    print("device:", device)
    print('envplotter:', envplotter)
    print(20 * "-")

    Config.train_number_epochs = nepoch

    networks = {
        "alexnet": SketchNetwork,
        "resnet": SketchNetworkResnet,
        "vgg": SketchNetworkVGG
    }

    net = networks[netname]()
    #net = nn.DataParallel(net)
    net = net.to(device)

    image_size = 224
    if netname == "inception":
        image_size = 299
        print(image_size)

    criterion_triplet = nn.TripletMarginLoss(margin=1.0)
    criterion_contrast = ContrastiveLoss()
    criterion_CE = nn.CrossEntropyLoss()

    losses = {
        "triplet":
        lambda x, y, z: criterion_triplet(x, y, z),
        "contrast":
        lambda x, y, z:
        (criterion_contrast(x, y,
                            torch.ones(Config.train_batch_size).to(device)) +
         criterion_contrast(
             x, z, -1 * torch.ones(Config.train_batch_size).to(device))),
        "crossEntropyLoss":
        lambda out0, out1: (criterion_CE(
            torch.cat((out0, out1)),
            torch.cat(
                (torch.zeros(Config.train_batch_size),
                 torch.ones(Config.train_batch_size))).to(device).long()))
    }

    criterion = losses[lossname]

    uses_triple = False
    if lossname == "triplet":
        uses_triple = True

    if opt != 'adm':
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    else:
        optimizer = optim.Adam(net.parameters(), lr=lr)

    plotter = VisdomLinePlotter(env_name=envplotter, port=8097)

    epoch_loss = 0
    valid_epoch_loss = 0
    best_loss = -1
    num_batch = 1
    num_batch_val = 1

    sh.mkdir("-p", dircheckpoint)
    files_checkpoints = np.array(
        sorted(glob.glob(dircheckpoint +
                         "/*{}_{}*".format(netname, lossname))))
    if (files_checkpoints.shape[0]):
        checkpoint = torch.load(files_checkpoints[-1])
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(files_checkpoints[-1])

    transf = transforms.Compose([
        transforms.RandomRotation((-45, 45), fill=(255, 255, 255, 1)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor()
    ])

    #DATA LOADERS
    sketch_dataset = SketchZoomDataset(data_sketch_root="data/",
                                       net=net,
                                       plotter=None,
                                       n=Config.train_data_n,
                                       stage="train",
                                       image_size=image_size,
                                       triplet=uses_triple,
                                       categories=[
                                           "Airplane", "Bag", "Cap", "Car",
                                           "Chair", "Earphone", "Guitar",
                                           "Knife", "Lamp", "Laptop",
                                           "Motorbike", "Mug", "Pistol",
                                           "Rocket", "Skateboard", "Table"
                                       ],
                                       transform=transf,
                                       device=device)

    train_dataloader = DataLoader(sketch_dataset,
                                  shuffle=True,
                                  num_workers=0,
                                  batch_size=Config.train_batch_size,
                                  drop_last=True)

    sketch_dataset_test = SketchZoomDataset(
        data_sketch_root="data/",
        net=net,
        plotter=None,  #plotter, 
        n=Config.val_data_n,
        stage="test",
        image_size=image_size,
        triplet=uses_triple,
        categories=[
            "Airplane", "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar",
            "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket",
            "Skateboard", "Table"
        ],
        transform=transf,
        device=device)

    validation_dataloader = DataLoader(sketch_dataset_test,
                                       shuffle=True,
                                       num_workers=0,
                                       batch_size=Config.val_batch_size,
                                       drop_last=True)

    for epoch in range(1, Config.train_number_epochs):
        print('EPOCH', epoch)
        valid_epoch_loss = 0
        epoch_loss = 0
        net.train()

        #TRAIN
        for i, data in enumerate(train_dataloader):
            sketch_dataset.net = net
            net.train()

            img0, img1, img2 = data
            img0, img1, img2 = Variable(img0).to(device), Variable(img1).to(
                device), Variable(img2).to(device)
            optimizer.zero_grad()

            if lossname != 'crossEntropyLoss':
                output1, output2, output3 = net(img0, img1, img2,
                                                img1.size()[0])
                loss = criterion(output1, output2, output3)
            else:
                output1, output2, res0 = net.forward_two_binary(
                    img0, img1,
                    img1.size()[0])
                output1, output3, res1 = net.forward_two_binary(
                    img1, img2,
                    img2.size()[0])
                loss = criterion(res0, res1)

            distances_negativa = F.pairwise_distance(output1, output3)
            distances_negativa = distances_negativa.data.cpu().numpy().flatten(
            )
            distances_positiva = F.pairwise_distance(output1, output2)
            distances_positiva = distances_positiva.data.cpu().numpy().flatten(
            )

            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()

            print('*' * 20)
            print(" nro Batch {} --  Current loss {}\n".format(i, loss.item()))
            num_batch = num_batch + 1
            plotter.plot('Distance mean', str(0), num_batch,
                         np.mean(distances_positiva), "Batchs")
            plotter.plot('Distance mean', str(1), num_batch,
                         np.mean(distances_negativa), "Batchs")
            plotter.plot('Batchs loss', str(epoch), i + 1, loss.item(),
                         "Batchs")
            if i != 0 and i % 9 == 0:
                torch.save(
                    {
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, "{}/{}_{}_{}.pkl".format(dircheckpoint,
                                                "current_batch_checkpoints",
                                                netname, lossname))
                print("save net, checkpoint Batch")

        del i, data, distances_negativa, distances_positiva, output1, output2, output3, img0, img1, img2

        net.eval()
        with torch.no_grad():

            sketch_dataset.net = net

            for i, data in enumerate(validation_dataloader):

                img0, img1, img2 = data
                img0, img1, img2 = Variable(img0).to(device), Variable(
                    img1).to(device), Variable(img2).to(device)

                if lossname != 'crossEntropyLoss':
                    output1, output2, output3 = net(img0, img1, img2,
                                                    img1.size()[0])
                    loss = criterion(output1, output2, output3)
                else:
                    output1, output2, res0 = net.forward_two_binary(
                        img0, img1,
                        img1.size()[0])
                    output1, output3, res1 = net.forward_two_binary(
                        img0, img2,
                        img2.size()[0])
                    loss = criterion(res0, res1)

                distances_negativa = F.pairwise_distance(output1, output3)
                distances_negativa = distances_negativa.data.cpu().numpy(
                ).flatten()
                distances_positiva = F.pairwise_distance(output1, output2)
                distances_positiva = distances_positiva.data.cpu().numpy(
                ).flatten()
                num_batch_val = num_batch_val + 1

                plotter.plot('Distance mean Valid', str(0), num_batch_val,
                             np.mean(distances_positiva[0]), "Batchs")

                plotter.plot('Distance mean Valid', str(1), num_batch_val,
                             np.mean(distances_negativa[0]), "Batchs")
                print(" nro Valid Batch{} --  Valid loss {}\n".format(
                    i + 1, loss.item()))
                valid_epoch_loss += loss.item()

            del i, data, distances_negativa, distances_positiva, output1, output2, output3, img0, img1, img2

        #END TRAIN
        current_epoch_loss = epoch_loss / (Config.train_data_n //
                                           Config.train_batch_size)
        current_epoch_loss_val = valid_epoch_loss / (Config.val_data_n //
                                                     Config.val_batch_size)

        print("Epoch number {}\n Current loss average {}\n".format(
            epoch, current_epoch_loss))
        print("Epoch number {}\n Current loss val average {}\n".format(
            epoch, current_epoch_loss_val))
        plotter.plot('Epochs loss ', 'train epoch', epoch, current_epoch_loss,
                     "Epochs")
        plotter.plot('Epochs loss ', 'valid epoch', epoch,
                     current_epoch_loss_val, "Epochs")

        #SAVE NET WITH BEST LOSS
        if (best_loss == -1 or current_epoch_loss < best_loss):
            torch.save(
                {
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, "{}/{}_{}_{}.pkl".format(dircheckpoint, "checkpoints",
                                            netname, lossname))

            best_loss = current_epoch_loss
            print("save net, loss: {}".format(best_loss))
Code example #9
                                batch_size=args.test_batch)
    

    nfeat = train_loader.__iter__().__next__()['input_anchor']['x'].shape[1]
    print("NFEAT: ",nfeat)
    print("Model: ",args.model)
    print("Scheduler: On") if args.no_scheduler else print("Scheduler: Off")
    if not args.no_windowed and args.input_type=='RST': print("Window: On")


    elif args.model == 'gcn_cheby':
        model = Siamese_GeoChebyConv(nfeat=nfeat,
                                     nhid=args.hidden,
                                     nclass=1,
                                     dropout=args.dropout)
        criterion = ContrastiveLoss(args.loss_margin)

    elif args.model == 'gcn_cheby_bce':
        model = Siamese_GeoChebyConv_Read(nfeat=nfeat,
                                     nhid=args.hidden,
                                     nclass=1,
                                     dropout=args.dropout)
        criterion = BCEWithLogitsLoss()

    elif args.model == 'gcn_cheby_cos':
        model = Siamese_GeoCheby_Cos(nfeat=nfeat,
                                     nhid=args.hidden,
                                     nclass=1,
                                     dropout=args.dropout)
        criterion = ContrastiveCosineLoss(args.temperature).to(device)
Code example #10
File: train_xh.py  Project: layumi/NLP-AICity2021
y_loss = {}  # loss history
y_loss['train'] = []
y_loss['val'] = []
y_err = {}
y_err['train'] = []
y_err['val'] = []


def l2_norm(v):
    fnorm = torch.norm(v, p=2, dim=1, keepdim=True) + 1e-6
    v = v.div(fnorm.expand_as(v))
    return v


xhloss = ContrastiveLoss()


def compute_loss(model, input_ids, attention_mask, crop, motion, nl_id,
                 crop_id, label, warm):
    if opt.motion:
        visual_embeds, lang_embeds, predict_class_v, predict_class_l, predict_class_motion = model.forward(
            input_ids, attention_mask, crop, motion.cuda())
    else:
        visual_embeds, lang_embeds, predict_class_v, predict_class_l = model.forward(
            input_ids, attention_mask, crop)
    #print(similarity.shape, predict_class_v.shape, predict_class_l.shape)
    #print(label.shape, nl_id.shape)
    #label = label.float()

    visual_embeds = l2_norm(visual_embeds)