Example No. 1
    def build_model(self):
        # placeholder
        self.x_images = tf.placeholder(
            tf.float32,
            [None, self.image_size, self.image_size, self.image_channel],
            name='image')
        self.y_labels = tf.placeholder(tf.int64, [None], name='label')

        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        # self.is_train = tf.placeholder(tf.bool, name='is_train')

        self.embeddings = CNN(self.x_images, self.keep_prob, self.n_embeddings)

        # triplet loss
        self.loss, _ = batch_all_triplet_loss(self.y_labels,
                                              self.embeddings,
                                              margin=self.margin)

        # optimizer
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.loss)

        # accuracy
        self.true_label = tf.placeholder(tf.int64, [None], name='true_label')
        self.pred_label = tf.placeholder(tf.int64, [None], name='pred_label')
        self.accu = tf.reduce_mean(
            tf.cast(tf.equal(self.true_label, self.pred_label), tf.float32))
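Every snippet on this page calls some variant of batch_all_triplet_loss. As a reference point, here is a minimal NumPy sketch of the batch-all mining strategy; the (loss, fraction) return convention mirrors the tests further down, so treat this as an illustration under that assumption, not the exact library code:

import numpy as np

def batch_all_triplet_loss_np(labels, embeddings, margin, squared=False):
    """Average hinge loss over all valid (anchor, positive, negative)
    triplets that have a strictly positive loss."""
    # Pairwise squared distances via ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
    dot = embeddings @ embeddings.T
    sq_norms = np.diag(dot)
    dists = np.maximum(sq_norms[:, None] - 2.0 * dot + sq_norms[None, :], 0.0)
    if not squared:
        dists = np.sqrt(dists)

    n = len(labels)
    loss_sum, num_positive, num_valid = 0.0, 0, 0
    for a in range(n):
        for p in range(n):
            for q in range(n):
                distinct = a != p and a != q and p != q
                valid = labels[a] == labels[p] and labels[a] != labels[q]
                if distinct and valid:
                    num_valid += 1
                    trip = max(0.0, dists[a, p] - dists[a, q] + margin)
                    loss_sum += trip
                    num_positive += trip > 0

    loss = loss_sum / max(num_positive, 1)   # average over positive triplets
    fraction = num_positive / max(num_valid, 1)
    return loss, fraction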
def test_simple_batch_all_triplet_loss():
    """Test the triplet loss with batch-all triplet mining in a simple case.

    With only one class there are no valid triplets (no negatives exist),
    so the loss must be 0.
    """
    num_data = 10
    feat_dim = 6
    margin = 0.2
    num_classes = 1

    embeddings = np.random.rand(num_data, feat_dim).astype(np.float32)
    labels = np.random.randint(0, num_classes,
                               size=(num_data,)).astype(np.float32)
    labels, embeddings = torch.from_numpy(labels), torch.from_numpy(embeddings)

    for squared in [True, False]:
        loss_np = 0.0  # expected: no valid triplets, hence zero loss

        # Compute the loss with the batch-all implementation (PyTorch here).
        loss_tf_val, fraction_val = batch_all_triplet_loss(labels,
                                                           embeddings,
                                                           margin,
                                                           squared=squared)

        assert np.allclose(loss_np, loss_tf_val)
        assert np.allclose(fraction_val, 0.0)
Example No. 3
def triplet_loss(y_true, y_pred):
    #one_hot_label = keras.utils.to_categorical(y_true[:,0,0], num_classes = 10988)
    # softmax_ce = tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_label, logits=y_pred)
    label = K.flatten(y_true[:, 0, 0])
    # softmax = tf.reduce_mean(softmax_ce)
    # This variant also returns the triplet mask alongside the fraction.
    triplet_loss, fraction_positive_triplets, mask = batch_all_triplet_loss(
        label, y_pred[:, 1], y_pred[:, 0], 0.2)
    print(fraction_positive_triplets)  # debug output
    return triplet_loss
def train_model_with_tripletloss(model,
                                 train_data,
                                 train_labels,
                                 test_data,
                                 test_labels,
                                 num_classes,
                                 len_encoding,
                                 num_epochs=20,
                                 batch_size=128,
                                 learning_rate=0.001,
                                 margin=0.5,
                                 triplet_loss_strategy="batch_hard"):

    # Generate tf.data.Dataset
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_data, train_labels))
    train_dataset = train_dataset.batch(batch_size)

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Accumulator for the total loss over an epoch
    overall_triplet_loss = tf.Variable(0.0, dtype=tf.float32)

    # Train network
    for epoch in range(num_epochs):
        # Iterate over minibatches
        overall_triplet_loss.assign(0.0)
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            triplet_loss_value = train_one_step_tripletloss(
                model, x_batch_train, y_batch_train, triplet_loss_strategy,
                margin, optimizer)

            overall_triplet_loss.assign_add(triplet_loss_value)
            #if(epoch==0):
            #    print('Step: {} Loss: {}'.format(step, triplet_loss_value))
        # Average over the number of batches seen this epoch.
        print('Epoch: {}: Average Train Loss: {} Last Batch Loss: {}'.format(
            epoch,
            overall_triplet_loss.numpy() / (step + 1), triplet_loss_value))

        # Model evaluation on test set for this epoch
        test_dataset = tf.data.Dataset.from_tensor_slices(
            (test_data, test_labels))
        test_dataset = test_dataset.batch(batch_size)
        overall_triplet_loss.assign(0.0)
        for test_step, (x_batch_test, y_batch_test) in enumerate(test_dataset):
            embeddings = model(x_batch_test)
            if triplet_loss_strategy == "batch_hard":
                triplet_loss_value = triplet_loss.batch_hard_triplet_loss(
                    y_batch_test, embeddings, margin)
            elif triplet_loss_strategy == "batch_all":
                triplet_loss_value = triplet_loss.batch_all_triplet_loss(
                    y_batch_test, embeddings, margin)
            overall_triplet_loss.assign_add(triplet_loss_value)
        print('Epoch: {}: Average Test Loss: {} Last Batch Loss: {}'.format(
            epoch,
            overall_triplet_loss.numpy() / (test_step + 1), triplet_loss_value))
    def batch_all_triplet_loss(self,
                               embeddings,
                               labels,
                               margin=0.5,
                               squared=False):
        """Thin wrapper delegating to the module-level implementation."""
        return triplet_loss.batch_all_triplet_loss(labels,
                                                   embeddings,
                                                   margin,
                                                   squared=squared)
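A sketch of how the trainer above might be wired up; the toy model and data shapes below are hypothetical placeholders (and the triplet_loss module used inside the trainer is assumed importable), not part of the original code:

import numpy as np
import tensorflow as tf

# Hypothetical toy setup: 28x28 grayscale images embedded into 64 dimensions.
len_encoding = 64
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(len_encoding),  # no activation on the embedding layer
])

train_data = np.random.rand(512, 28, 28, 1).astype(np.float32)
train_labels = np.random.randint(0, 10, size=512)
test_data = np.random.rand(128, 28, 28, 1).astype(np.float32)
test_labels = np.random.randint(0, 10, size=128)

train_model_with_tripletloss(model, train_data, train_labels,
                             test_data, test_labels,
                             num_classes=10, len_encoding=len_encoding,
                             num_epochs=5, margin=0.5,
                             triplet_loss_strategy="batch_all")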
Example No. 6
def build_network(input_images, labels, ratio=0.5):
    logits, embedding = inference(input_images)

    with tf.name_scope('loss'):
        with tf.name_scope('triplet_loss'):
            # NOTE: the `ratio` argument is unused; the margin is hardcoded.
            total_loss, fraction = batch_all_triplet_loss(
                labels, embedding, margin=0.5, squared=False)

    # The trailing slash re-enters the existing 'loss' scope for the summary.
    with tf.name_scope('loss/'):
        tf.summary.scalar('TotalLoss', total_loss)

    return total_loss, fraction, logits, embedding
def train_one_step_tripletloss(model, x_batch_train, y_batch_train,
                               triplet_loss_strategy, margin, optimizer):
    with tf.GradientTape(persistent=False) as tape:
        embeddings = model(x_batch_train)
        if (triplet_loss_strategy == "batch_hard"):
            triplet_loss_value = triplet_loss.batch_hard_triplet_loss(
                y_batch_train, embeddings, margin)
        elif (triplet_loss_strategy == "batch_all"):
            triplet_loss_value = triplet_loss.batch_all_triplet_loss(
                y_batch_train, embeddings, margin)

    grads = tape.gradient(triplet_loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return triplet_loss_value
Example No. 8
def finetune_on_support(model,
                        Dataloader,
                        orig_target_dict,
                        optimizer,
                        epochs=6,
                        logger=None,
                        margin=0.1,
                        Lambda=10):
    MSE = nn.MSELoss()
    model.train()

    for epoch in range(epochs):
        print("\n \n" + ("-" * 20))
        print("Epoch : ", epoch)
        print("-" * 20)
        for batch_id, data in enumerate(Dataloader):
            imgs = data[0].cuda()
            labels = data[1].cuda()
            idxs = data[2]
            gt_label = data[3]

            optimizer.zero_grad()

            target_old_task = torch.Tensor(orig_target_dict[idxs, :]).cuda()
            orig_feats, new_feats = model(imgs)

            L_old = MSE(orig_feats, target_old_task)
            L_new, frac_valid_trip = batch_all_triplet_loss(labels,
                                                            new_feats,
                                                            margin=margin,
                                                            squared=False)
            L_total = Lambda * L_old + L_new
            # L_total = L_new
            sys.stdout.write(
                "\rLoss Total : %f L_trip : %f L_old : %f, Frac Valid Trips: %f"
                % (L_total.item(), L_new.item(), L_old.item(),
                   frac_valid_trip.item()))
            sys.stdout.flush()

            if logger is not None:
                logger.add_scalar("loss/iter", L_total.item(),
                                  epoch * len(Dataloader) + batch_id)

            L_total.backward()
            optimizer.step()
    model.eval()
    return L_old, L_new
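The loop above assumes a model with two output heads (orig_feats, new_feats = model(imgs)): one regressed toward the stored old-task targets with MSE, one fine-tuned with the triplet loss. A hypothetical minimal module with that interface; the names and dimensions are illustrative, not from the original code:

import torch.nn as nn

class TwoHeadModel(nn.Module):
    def __init__(self, backbone, feat_dim=512, emb_dim=128):
        super().__init__()
        self.backbone = backbone  # e.g. a CNN mapping images to (B, feat_dim)
        self.old_head = nn.Linear(feat_dim, emb_dim)  # matched against orig_target_dict
        self.new_head = nn.Linear(feat_dim, emb_dim)  # trained with the triplet loss

    def forward(self, x):
        feats = self.backbone(x)
        return self.old_head(feats), self.new_head(feats)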
def test_batch_all_triplet_loss():
    """Test the triplet loss with batch all triplet mining"""
    num_data = 10
    feat_dim = 6
    margin = 0.2
    num_classes = 5

    embeddings = np.random.rand(num_data, feat_dim).astype(np.float32)
    labels = np.random.randint(0, num_classes,
                               size=(num_data,)).astype(np.float32)

    for squared in [True, False]:
        pdist_matrix = pairwise_distance_np(embeddings, squared=squared)

        loss_np = 0.0
        num_positives = 0.0
        num_valid = 0.0
        for i in range(num_data):
            for j in range(num_data):
                for k in range(num_data):
                    distinct = (i != j and i != k and j != k)
                    valid = (labels[i] == labels[j]) and (labels[i] != labels[k])
                    if distinct and valid:
                        num_valid += 1.0

                        pos_distance = pdist_matrix[i][j]
                        neg_distance = pdist_matrix[i][k]

                        loss = np.maximum(0.0,
                                          pos_distance - neg_distance + margin)
                        loss_np += loss

                        num_positives += (loss > 0)

        loss_np /= num_positives

        # Compute the loss in TF.
        loss_tf, fraction = batch_all_triplet_loss(labels,
                                                   embeddings,
                                                   margin,
                                                   squared=squared)
        with tf.Session() as sess:
            loss_tf_val, fraction_val = sess.run([loss_tf, fraction])
        assert np.allclose(loss_np, loss_tf_val)
        assert np.allclose(num_positives / num_valid, fraction_val)
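This test calls a pairwise_distance_np helper that is not shown on this page. A minimal sketch consistent with how it is used here (pairwise squared or plain Euclidean distances between rows):

import numpy as np

def pairwise_distance_np(feature, squared=False):
    """Pairwise distance matrix between all rows of `feature`.

    Uses the identity ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
    """
    dot = feature @ feature.T
    sq_norms = np.diag(dot)
    dists = sq_norms[:, None] - 2.0 * dot + sq_norms[None, :]
    dists = np.maximum(dists, 0.0)  # clamp tiny negatives from rounding
    if not squared:
        dists = np.sqrt(dists)
    return dists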
Example No. 10
def cnn_model(features, labels, mode):
    images = features['images']
    filenames = features['filenames']
    onehot_labels = labels
    axillary_labels = features['axillary_labels']

    if FLAGS.network == 'alexnet':
        # Format data
        if FLAGS.data_format == 'NCHW':
            print(colored("Converting data format to channels first (NCHW)", \
                    'blue'))
            images = tf.transpose(images, [0, 3, 1, 2])

        # Setup batch normalization
        if mode == tf.estimator.ModeKeys.TRAIN:
            norm_params={'is_training':True, 
                    'data_format': FLAGS.data_format}
        else:
            norm_params={'is_training':False,
                    'data_format': FLAGS.data_format,
                    'updates_collections': None}

        # Create the network
        logits = alexnet(images, norm_params, mode) 

    elif FLAGS.network == 'resnet':
        logits, end_points = resnet_v2.resnet_v2_50(inputs=images, 
                num_classes=ts._NUM_CLASSES, 
                is_training=(mode==tf.estimator.ModeKeys.TRAIN))

    # Inference
    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode,
                predictions={
                    'pred_class': predicted_classes,
                    'gt_class': axillary_labels,
                    'embedding': logits,
                    # 'prob': tf.nn.softmax(logits),
                })

    # Training 
    groundtruth_classes = tf.argmax(onehot_labels, 1)
    if FLAGS.mode == "triplet_training":
        if FLAGS.triplet_mining_method == "batchall":
            loss, fraction_positive_triplets, num_valid_triplets = \
                    triplet_loss.batch_all_triplet_loss(
                    axillary_labels, logits, FLAGS.triplet_margin)
        elif FLAGS.triplet_mining_method == "batchhard":
            loss = triplet_loss.batch_hard_triplet_loss(
                    axillary_labels, logits, FLAGS.triplet_margin)
        else:
            "ERROR: Wrong Triplet loss mining method, using softmax"
            loss = tf.losses.softmax_cross_entropy(
                    onehot_labels=onehot_labels, logits=logits)
        if FLAGS.loss_mode == "mix":
            loss += tf.losses.softmax_cross_entropy(
                    onehot_labels=onehot_labels, logits=logits)

    else:
        loss = tf.losses.softmax_cross_entropy(
                onehot_labels=onehot_labels, logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if FLAGS.optimizer == 'GD':
            decay_factor = 0.96
            learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                    tf.train.get_global_step(),
                    int(math.ceil(float(ts._SPLITS_TO_SIZES['train'] / 
                        FLAGS.batch_size))),
                    decay_factor)
            optimizer = tf.train.GradientDescentOptimizer(
                    learning_rate=learning_rate)
        elif FLAGS.optimizer == 'Momentum':
            decay_factor = 0.96
            learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                    tf.train.get_global_step(),
                    int(math.ceil(float(ts._SPLITS_TO_SIZES['train'] / 
                        FLAGS.batch_size))),
                    decay_factor)
            optimizer = tf.train.MomentumOptimizer(
                    learning_rate=learning_rate, momentum=0.9)
        else:
            optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.learning_rate)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, 
                    global_step=tf.train.get_global_step())
            return tf.estimator.EstimatorSpec(
                    mode, loss=loss, train_op=train_op)

    # Testing
    # top_5 = tf.metrics.precision_at_top_k(
            # labels=groundtruth_classes, 
            # predictions=predicted_classes,
            # k = 5)
    # top_10 = tf.metrics.precision_at_top_k(
            # labels=groundtruth_classes, 
            # predictions=predicted_classes,
            # k = 10)
    eval_metric_ops = {
            'eval/accuracy': tf.metrics.accuracy(
                labels=groundtruth_classes, 
                predictions=predicted_classes),
            # 'eval/accuracy_top5': top_5,
            # 'eval/accuracy_top10': top_10,
            }
    return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)
Example No. 11
def my_model(features, labels, mode, params):
    '''
    model_fn for tf.estimator: builds the model, training ops, etc.

    Args:
        features: input images, shape = (batch_size, 784)
        labels:   targets, shape = (batch_size, )
        mode:     an instance of tf.estimator.ModeKeys (the current phase)
        params:   dict of hyperparameters
    '''
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    images = features
    images = tf.reshape(
        images, shape=[-1, params['image_size'], params['image_size'],
                       1])  # reshape to (batch_size, img_size, img_size, 1)
    with tf.variable_scope("model"):
        embeddings = build_model(is_training, images, params)  # build the model

    if mode == tf.estimator.ModeKeys.PREDICT:  # at predict time, just return the embeddings
        predictions = {'embeddings': embeddings}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
    # Apply the selected triplet-loss strategy.
    labels = tf.cast(labels, tf.int64)
    if params['triplet_strategy'] == 'batch_all':
        loss, fraction = batch_all_triplet_loss(labels,
                                                embeddings,
                                                margin=params['margin'],
                                                squared=params['squared'])
    elif params['triplet_strategy'] == 'batch_hard':
        loss = batch_hard_triplet_loss(labels,
                                       embeddings,
                                       margin=params['margin'],
                                       squared=params['squared'])
    else:
        raise ValueError("triplet_strategy 配置不正确: {}".format(
            params['triplet_strategy']))

    embedding_mean_norm = tf.reduce_mean(tf.norm(
        embeddings, axis=1))  # mean L2 norm of the embeddings
    tf.summary.scalar("embedding_mean_norm", embedding_mean_norm)
    with tf.variable_scope("metrics"):
        eval_metric_ops = {
            'embedding_mean_norm': tf.metrics.mean(embedding_mean_norm)
        }
        if params['triplet_strategy'] == 'batch_all':
            eval_metric_ops['fraction_positive_triplets'] = tf.metrics.mean(
                fraction)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          eval_metric_ops=eval_metric_ops)

    tf.summary.scalar('loss', loss)
    if params['triplet_strategy'] == "batch_all":
        tf.summary.scalar('fraction_positive_triplets', fraction)
    tf.summary.image('train_image', images, max_outputs=1)  # log at most one image

    optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
    global_step = tf.train.get_global_step()
    if params['use_batch_norm']:
        # Batch norm keeps moving estimates of the batch statistics; the
        # update ops live in tf.GraphKeys.UPDATE_OPS and must run before the
        # train op, hence the tf.control_dependencies block.
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss, global_step=global_step)
    else:
        train_op = optimizer.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
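For context, a sketch of how my_model would be handed to tf.estimator (TF 1.x); the params values and model_dir below are illustrative, not taken from the original project:

import tensorflow as tf

params = {
    'image_size': 28,
    'margin': 0.5,
    'squared': False,
    'triplet_strategy': 'batch_all',
    'learning_rate': 1e-3,
    'use_batch_norm': False,
}

estimator = tf.estimator.Estimator(model_fn=my_model,
                                   model_dir='experiments/batch_all',
                                   params=params)
# estimator.train(input_fn=train_input_fn)  # train_input_fn yields (images, labels)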
Example No. 12
    def _build_graph(self, input_vars):
        image, label = input_vars
        image = image / 128.0 - 1

        def conv(name, l, channel, stride):
            return Conv2D(name,
                          l,
                          channel,
                          3,
                          stride=stride,
                          nl=tf.identity,
                          use_bias=False,
                          W_init=tf.random_normal_initializer(
                              stddev=np.sqrt(2.0 / 9 / channel)))

        def add_layer(name, l):
            in_shape = l.get_shape().as_list()
            in_channel = in_shape[3]
            with tf.variable_scope(name) as scope:
                c = BatchNorm('bn1', l)
                c = tf.nn.relu(c)
                c = conv('conv1', c, self.growthRate, 1)
                l = tf.concat([c, l], 3)
            return l

        def add_transition(name, l):
            shape = l.get_shape().as_list()
            in_channel = shape[3]
            with tf.variable_scope(name) as scope:
                l = BatchNorm('bn1', l)
                l = tf.nn.relu(l)
                l = Conv2D('conv1',
                           l,
                           in_channel,
                           1,
                           stride=1,
                           use_bias=False,
                           nl=tf.nn.relu)
                l = AvgPooling('pool', l, 2)
                #l = MaxPooling('pool', l, 2)
            return l

        def dense_net(name):
            l = conv('conv0', image, self.growthRate, 1)
            with tf.variable_scope('block1') as scope:

                for i in range(self.N):
                    l = add_layer('dense_layer.{}'.format(i), l)
                l = add_transition('transition1', l)

            with tf.variable_scope('block2') as scope:

                for i in range(self.N):
                    l = add_layer('dense_layer.{}'.format(i), l)
                l = add_transition('transition2', l)

            with tf.variable_scope('block3') as scope:

                for i in range(self.N):
                    l = add_layer('dense_layer.{}'.format(i), l)

            l = BatchNorm('bnlast', l)
            l = tf.nn.relu(l)
            l = GlobalAvgPooling('gap', l)

            logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)

            return logits, l

        logits, embeddings = dense_net("dense_net")

        prob = tf.nn.softmax(logits, name='output')

        #cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        #cost = tf.reduce_mean(cost, name='cross_entropy_loss')
        #triplet_cost, fraction = batch_all_triplet_loss(label, embeddings, margin=0.5, squared=False)
        cost, fraction = batch_all_triplet_loss(label,
                                                embeddings,
                                                margin=0.5,
                                                squared=True)

        wrong = prediction_incorrect(logits, label)
        # monitor training error
        add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

        # weight decay on all W
        wd_cost = tf.multiply(1e-4,
                              regularize_cost('.*/W', tf.nn.l2_loss),
                              name='wd_cost')
        #add_moving_summary(cost, triplet_cost, wd_cost)
        add_moving_summary(cost, wd_cost)
        add_moving_summary(fraction)

        add_param_summary(('.*/W', ['histogram']))  # monitor W
        #self.cost = tf.add_n([cost, triplet_cost, wd_cost], name='cost')
        self.cost = tf.add_n([cost, wd_cost], name='cost')
Example No. 13
def Fine_tuning(args, train_episodes_idx, graphs_list, Dagcn, Gin, Graph_reps_gcn, FTL_Classifier, \
                Dagcn_optimizer, Gin_optimizer, gcn_optimizer, FTL_optimizer, ft_epoch, ft_criterion, device, test_episodes_idx=None):

    if args.JS_div:
        JS_lamda = adjust_lamda(args, ft_epoch)

    ft_loss_accum = 0.0
    Correct = 0.0
    count_samples = 0.0
    
    if not args.Fixed_weights:
        for param_group in Dagcn_optimizer.param_groups:
            print('Dagcn_optimizer lr:', param_group['lr'])
        for param_group in Gin_optimizer.param_groups:
            print('Gin_optimizer lr:', param_group['lr'])
        for param_group in gcn_optimizer.param_groups:
            print('gcn_optimizer lr:', param_group['lr'])
    for param_group in FTL_optimizer.param_groups:
        print('FTL_optimizer lr:', param_group['lr'])
    
    for train_episode_idx in train_episodes_idx:
        count_samples += float(len(train_episode_idx))

        if Dagcn_optimizer is not None:
            Dagcn_optimizer.zero_grad()
            Gin_optimizer.zero_grad()
            gcn_optimizer.zero_grad()
        FTL_optimizer.zero_grad()

        train_episode = [graphs_list[idx] for idx in train_episode_idx]
        
        Gin_relation_mat, Gin_embed = Gin(train_episode)  #######
        Dagcn_relation_mat, Dagcn_embed = Dagcn(train_episode) ########
            
        graph_representations_1 = Graph_reps_gcn(Gin_embed, Gin_relation_mat)
        graph_representations_2 = Graph_reps_gcn(Dagcn_embed, Dagcn_relation_mat)

        Z_1 = FTL_Classifier(graph_representations_1)
        Z_2 = FTL_Classifier(graph_representations_2)

        query_labels = torch.LongTensor([graphs_list[idx].label for idx in train_episode_idx]).to(device)
    
        Z = Z_1 + Z_2
        loss = ft_criterion(Z, query_labels)

        pred = Z.max(1, keepdim=True)[1]
        correct = pred.eq(query_labels.view_as(pred)).sum().cpu().item()
        Correct += correct
        
        if not args.Fixed_weights:
            # triplet_l_1 = batch_hard_triplet_loss(query_labels, Dagcn_embed, device, margin=args.triplet_m)
            triplet_l_1, _ = batch_all_triplet_loss(query_labels, Dagcn_embed, device, margin=args.triplet_m)
            # triplet_l_2 = batch_hard_triplet_loss(query_labels, Gin_embed, device, margin=args.triplet_m)
            triplet_l_2, _ = batch_all_triplet_loss(query_labels, Gin_embed, device, margin=args.triplet_m)
                
            loss = args.CE_lambda * loss + (triplet_l_1 + triplet_l_2)

        if args.JS_div:
            JS_loss = Jensen_Shannon_div(float(len(train_episode_idx)), Z_1, Z_2)
            loss = loss + JS_lamda * JS_loss
            
        loss.backward()


        if Dagcn_optimizer is not None:
            # nn.utils.clip_grad_norm_(Dagcn.parameters(), args.clip)
            nn.utils.clip_grad_norm_(Gin.parameters(), args.clip)
            nn.utils.clip_grad_norm_(Graph_reps_gcn.parameters(), args.clip)
        nn.utils.clip_grad_norm_(FTL_Classifier.parameters(), args.clip)
    
        if Dagcn_optimizer is not None:
            Dagcn_optimizer.step()
            Gin_optimizer.step()
            gcn_optimizer.step()
        FTL_optimizer.step()
    
        loss = loss.detach().cpu().numpy()
        ft_loss_accum += loss
        
    
    average_loss = ft_loss_accum/len(train_episodes_idx)    
    acc = Correct/count_samples

    print("Fine_tuning loss: ", average_loss, "accuracy:", acc)
    
    test_acc = 0.0
    if test_episodes_idx is not None:
        test_acc = test(args, test_episodes_idx, graphs_list, Dagcn, Gin, Graph_reps_gcn, FTL_Classifier, ft_epoch, device)

    return test_acc
Example No. 14
def meta_train(args, train_episodes_idx, test_episodes_idx, graphs_list, Dagcn, Gin, Graph_reps_gcn, cls_head, Dagcn_optimizer, Gin_optimizer, gcn_optimizer, \
               count_update_epoch, train_way, test_way, device, target_episodes_idx=None, target_graphs_list=None):

    loss_accum = 0.0

    num_samples = len(train_episodes_idx[0])
    num_support_samples = train_way * args.train_shot
    num_query_samples = float(num_samples) - num_support_samples
 
    ACC = []

    for train_episode_idx in train_episodes_idx:
        Dagcn_optimizer.zero_grad()
        Gin_optimizer.zero_grad()
        gcn_optimizer.zero_grad()

        train_episode = [graphs_list[idx] for idx in train_episode_idx]
        
        # train_episode_label = [g.label for g in train_episode]
        # print('train_episode_label:', train_episode_label)
        
        Gin_relation_mat, Gin_embed = Gin(train_episode)  #######
        Dagcn_relation_mat, Dagcn_embed = Dagcn(train_episode) ########
            
        graph_representations_1 = Graph_reps_gcn(Gin_embed, Gin_relation_mat)
        graph_representations_2 = Graph_reps_gcn(Dagcn_embed, Dagcn_relation_mat)

        # labels = torch.LongTensor([graph.label for graph in train_episode]).to(device)
        # query_labels = labels[train_episode_mask_inverse]
    
        new_labels = torch.arange(train_way).repeat(args.num_query + args.train_shot)
        # new_labels = new_labels.type(torch.cuda.LongTensor)
        new_labels = new_labels.type(torch.LongTensor).to(device)
        query_labels = new_labels[num_support_samples:]
        
        s_labels = new_labels[:num_support_samples]
        
        support_reps_1 = graph_representations_1[:num_support_samples]
        query_reps_1   = graph_representations_1[num_support_samples:]
        support_reps_2 = graph_representations_2[:num_support_samples]
        query_reps_2   = graph_representations_2[num_support_samples:]


        if args.com_support_triplet_l:
            # s_triplet_l_1    = batch_hard_triplet_loss(s_labels, support_reps_1, device, margin=args.triplet_m)
            s_triplet_l_1, _ = batch_all_triplet_loss(s_labels, support_reps_1, device, margin=args.triplet_m)
            # s_triplet_l_2    = batch_hard_triplet_loss(s_labels, support_reps_2, device, margin=args.triplet_m)
            s_triplet_l_2, _ = batch_all_triplet_loss(s_labels, support_reps_2, device, margin=args.triplet_m)

        query_output_1 = few_shot_classification(args, cls_head, support_reps_1, query_reps_1, \
                                                 s_labels, train_way, args.train_shot, device)
        query_output_2 = few_shot_classification(args, cls_head, support_reps_2, query_reps_2, \
                                                 s_labels, train_way, args.train_shot, device)
        

        query_output = query_output_1 + query_output_2
        loss = criterion(query_output, query_labels)

        # print('CE LOSS:', loss)
        # print('query_output:', query_output)
            
        query_pred = query_output.max(1, keepdim=True)[1]
        correct = query_pred.eq(query_labels.view_as(query_pred)).sum().cpu().item()
        acc = correct / float(num_query_samples)
        ACC.append(acc)
        
        # triplet_l_1 = batch_hard_triplet_loss(new_labels, Dagcn_embed, device, margin=args.triplet_m)
        triplet_l_1, _ = batch_all_triplet_loss(new_labels, Dagcn_embed, device, margin=args.triplet_m)
        # triplet_l_2 = batch_hard_triplet_loss(new_labels, Gin_embed, device, margin=args.triplet_m)
        triplet_l_2, _ = batch_all_triplet_loss(new_labels, Gin_embed, device, margin=args.triplet_m)
        
        if args.com_support_triplet_l:
            loss = args.CE_lambda * loss + (triplet_l_1 + triplet_l_2 + s_triplet_l_1 + s_triplet_l_2)
        else:
            loss = args.CE_lambda * loss + (triplet_l_1 + triplet_l_2)

        if args.JS_div:
            # JS_lamda is not set anywhere in this function; assuming it is
            # obtained the same way as in Fine_tuning above.
            JS_lamda = adjust_lamda(args, count_update_epoch)
            JS_loss = Jensen_Shannon_div(num_query_samples, query_output_1, query_output_2)
            loss = loss + JS_lamda * JS_loss
            # print('JS_loss:', JS_loss)
            
        loss.backward()


        # nn.utils.clip_grad_norm_(Dagcn.parameters(), args.clip)
        nn.utils.clip_grad_norm_(Gin.parameters(), args.clip)
        nn.utils.clip_grad_norm_(Graph_reps_gcn.parameters(), args.clip)

        Dagcn_optimizer.step()
        Gin_optimizer.step()
        gcn_optimizer.step()
    
        loss = loss.detach().cpu().numpy()
        loss_accum += loss


    average_loss = loss_accum/len(train_episodes_idx)

    print('count_update_epoch:', count_update_epoch, "train loss: ", average_loss)
    m, h = mean_confidence_interval(ACC)
    print('train', "accuracy: %f, h: %f" % (m, h))

    return Dagcn.state_dict(), Gin.state_dict(), Graph_reps_gcn.state_dict()
Example No. 15
    def call(self, x):
        # NOTE: this variant passes a single argument, so the
        # batch_all_triplet_loss used here presumably bundles the labels and
        # embeddings in `x`; the variants elsewhere on this page take
        # (labels, embeddings, margin) as separate arguments.
        triplet_loss = batch_all_triplet_loss(x)
        return triplet_loss
Example No. 16
    def loss(y_true, y_pred):
        # `margin` is closed over from the enclosing scope; the fraction of
        # positive triplets (pos_trip) is computed but discarded.
        loss, pos_trip = batch_all_triplet_loss(y_true, y_pred, margin)
        return loss
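Keras expects a loss(y_true, y_pred) callable, which is why the margin is captured in a closure here. A sketch of the usual factory pattern around it; make_triplet_loss and embedding_model are hypothetical names, not from the original code:

def make_triplet_loss(margin=0.2):
    """Factory that closes over `margin`, as in the example above."""
    def loss(y_true, y_pred):
        loss_value, pos_trip = batch_all_triplet_loss(y_true, y_pred, margin)
        return loss_value
    return loss

# embedding_model.compile(optimizer='adam', loss=make_triplet_loss(0.2))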
Example No. 17
def model_fn(features, labels, mode, params):
    """Model function for tf.estimator

    Args:
        features: (tf.Tensor) input batch of images
        labels: (tf.Tensor) labels of the images
        mode: (int) An instance of tf.estimator.ModeKeys (TRAIN, EVAL, PREDICT)
        params: (dict) experiment parameters

    Returns:
        model_spec: (tf.estimator.EstimatorSpec)
    """

    images = features
    tf.summary.image('train_image', images, max_outputs=3)

    # Create model
    embeddings = build_model(images, mode, params)

    # if predicting for new data, just compute and return the embeddings
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'embeddings': embeddings}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Define loss
    if params.loss == "triplet_batch_hard":
        loss = batch_hard_triplet_loss(labels,
                                       embeddings,
                                       margin=params.margin)
    elif params.loss == "triplet_batch_all":
        loss = batch_all_triplet_loss(labels,
                                      embeddings,
                                      margin=params.margin,
                                      metrics=metrics)
    elif params.loss == "cross_entropy":
        one_hot_labels = tf.one_hot(labels, params.num_classes)
        loss = tf.losses.softmax_cross_entropy(one_hot_labels, embeddings)
    tf.summary.scalar('loss', loss)

    # Metrics
    metrics = {}
    embedding_mean_norm = tf.metrics.mean(tf.norm(embeddings, axis=1))
    metrics["metrics/embedding_mean_norm"] = embedding_mean_norm
    with tf.name_scope("metrics/"):
        tf.summary.scalar('embedding_mean_norm', embedding_mean_norm[1])

    if params.loss == "cross_entropy":
        predictions = tf.argmax(embeddings, 1)
        accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
        metrics["metrics/accuracy"] = accuracy
        # The trailing slash re-enters the existing "metrics" name scope
        # instead of creating a new, unique one.
        with tf.name_scope("metrics/"):
            # tf.summary.scalar makes accuracy available to TensorBoard
            # in both TRAIN and EVAL modes.
            tf.summary.scalar('accuracy', accuracy[1])

    if mode == tf.estimator.ModeKeys.EVAL:
        # Loss and metrics are computed on the validation set; values of
        # eval_metric_ops must be (metric_value, update_op) tuples.
        # See https://www.tensorflow.org/guide/custom_estimators
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

    # Optimizer
    optimizer = tf.train.AdamOptimizer(params.learning_rate)
    global_step = tf.train.get_global_step()

    train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode,
                                      loss=loss,
                                      train_op=train_op,
                                      training_hooks=[])
Example No. 18
    def triplet_loss(self, inp, targets):
        # `triplet_loss` in the body resolves to the imported module (the
        # method name only shadows it as a class attribute), so this
        # delegation works.
        loss = triplet_loss.batch_all_triplet_loss(targets, inp, 1.0)
        return loss
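A hypothetical minimal wrapper consistent with the method above, assuming a PyTorch triplet_loss module whose batch_all_triplet_loss takes (labels, embeddings, margin); all names below are illustrative:

import torch.nn as nn

import triplet_loss  # assumed module providing batch_all_triplet_loss


class TripletCriterion(nn.Module):
    """Criterion-style wrapper, usable where e.g. nn.CrossEntropyLoss would go."""

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, targets):
        return triplet_loss.batch_all_triplet_loss(targets, embeddings, self.margin)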