def train(train_params, paths, model_params, data_params):
    epochs = train_params['epochs']
    batch_size = train_params['batch_size']
    val_size = train_params['val_size']
    learning_rate = train_params['learning_rate']

    dataset_root = paths['dataset']
    triplets_path = paths['triplets']
    output_root = paths['output_root']

    proposal_img_size = data_params['proposal_img_size']
    scene = data_params['scene']

    triplet_img_size = model_params['triplet_img_size']
    margin = model_params['margin']
    embedding_model = model_params['embedding_model']

    # OUTPUT FOLDER SETUP
    SLURM_SUFFIX = datetime.now().strftime("%d-%m-%Y_%H_%M_%S")

    try:
        SLURM_JOB_ID = str(os.environ["SLURM_JOB_ID"])
        SLURM_SUFFIX = SLURM_JOB_ID
    except KeyError:
        print('Slurm Job Id not available')

    output_root = os.path.join(output_root, 'model_' + SLURM_SUFFIX)

    if not os.path.exists(output_root):
        os.makedirs(output_root)

    # DATASET SETUP
    dataset = ActiveVisionTriplet(dataset_root, triplets_path, instance=scene,
                                  image_size=proposal_img_size,
                                  triplet_image_size=triplet_img_size, get_labels=False,
                                  proposals_root=None, plot_original_proposals=False)
    num_val = round(len(dataset) * val_size)
    train_dataset, val_dataset = torch.utils.data.random_split(dataset,
                                                               [len(dataset) - num_val, num_val])

    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    # MODEL SETUP
    model = MetricLearningNet(model=embedding_model)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    losses = {'val_loss': [],
              'train_loss': []}

    with open(os.path.join(output_root, 'params.txt'), 'w') as f:
        _params = {
            'TRAIN_PARAMS': train_params,
            'PATHS': paths,
            'MODEL_PARAMS': model_params,
            'DATA_PARAMS': data_params
        }

        json.dump(_params, f)
        del _params

    with open(os.path.join(output_root, 'time.txt'), 'a') as f:
        f.write("START TIME: " + datetime.now().strftime("%m-%d-%Y_%H:%M:%S") + '\n')

    # TRAINING
    with tqdm(total=epochs) as epoch_bar:
        epoch_bar.set_description('Total Epoch Progress: '.ljust(20))

        for epoch in range(epochs):
            loss_per_iter = []

            with tqdm(total=len(train_data_loader)) as it_bar:
                it_bar.set_description(f'Epoch={epoch + 1}'.ljust(20))
                # start = time()

                for idx, (ref_list, pos_list, neg_list, labels) in enumerate(train_data_loader):
                    optimizer.zero_grad()
                    # print(f'Loading: {time()-start}')
                    # start = time()
                    ref_emb, pos_emb, neg_emb = model(ref_list, pos_list, neg_list)
                    # print(f'Forward: {time()-start}')

                    # start = time()
                    loss = triplet_loss(ref_emb, pos_emb, neg_emb, min_dist_neg=margin)
                    # print(f'Loss calc: {time()-start}')
                    loss_per_iter.append(loss.item())

                    # start = time()
                    loss.backward()
                    # print(f'Loss back: {time()-start}')

                    # start = time()
                    optimizer.step()
                    # print(f'Optimizer step: {time()-start}')

                    it_bar.set_postfix(train_loss=f'{loss:.4f}')
                    it_bar.update()

                    # start = time()

                losses['train_loss'].append(loss_per_iter)

                # VALIDATION of Epoch
                with torch.no_grad():
                    val_loss = 0
                    for ref_list, pos_list, neg_list, labels in val_data_loader:
                        ref_emb, pos_emb, neg_emb = model(ref_list, pos_list, neg_list)
                        val_loss += triplet_loss(ref_emb, pos_emb, neg_emb, min_dist_neg=margin)
                    val_loss = val_loss / len(val_data_loader)  # mean loss over validation batches

                losses['val_loss'].append(val_loss.item())

                it_bar.set_postfix(val_loss=f'{val_loss.item():.4f}')

            epoch_bar.update()

    torch.save(model.state_dict(), os.path.join(output_root, 'model_' +
                                                datetime.now().strftime("%m-%d-%Y_%H_%M_%S") +
                                                '.pth'))

    with open(os.path.join(output_root, 'losses.pickle'), 'wb') as f:
        pickle.dump(losses, f)

    with open(os.path.join(output_root, 'time.txt'), 'a') as f:
        f.write("END TIME: " + datetime.now().strftime("%m-%d-%Y_%H:%M:%S") + '\n')

    return model, losses
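
The `triplet_loss` called above is imported from elsewhere in the project and is not shown here. A minimal sketch of a compatible margin-based PyTorch version, assuming `min_dist_neg` acts as the margin and the result is averaged over the batch, might look like this:

import torch
import torch.nn.functional as F

def triplet_loss(ref_emb, pos_emb, neg_emb, min_dist_neg=0.2):
    # Hypothetical sketch: pull the positive toward the reference and push the
    # negative at least `min_dist_neg` further away, averaged over the batch.
    d_pos = F.pairwise_distance(ref_emb, pos_emb)
    d_neg = F.pairwise_distance(ref_emb, neg_emb)
    return torch.clamp(d_pos - d_neg + min_dist_neg, min=0.0).mean()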
Example #2
def train():
    train_data_loader = DatasetLoader(params.image_feat_path_train, params.sent_feat_path_train, params.label_path_train)
    im_feat_dim = train_data_loader.im_feat_shape
    sent_feat_dim = train_data_loader.sent_feat_shape
    train_num_samples = train_data_loader.no_samples
    steps_per_epoch = train_num_samples // params.batchSize
    num_steps = steps_per_epoch * params.maxEpoch
    
    with tf.Graph().as_default():
        # Setup placeholders for input variables.
        image_placeholder = tf.placeholder(dtype=tf.float32, shape=[params.batchSize, im_feat_dim])
        sent_placeholder = tf.placeholder(dtype=tf.float32, shape=[params.batchSize, sent_feat_dim])
        label_placeholder = tf.placeholder(dtype=tf.int32, shape=[params.batchSize])
        # is_training_placeholder = tf.placeholder(tf.bool)
        
        # create embedding model
        image_intermediate_embed_tensor = visual_feature_embed(image_placeholder, params.embedDim, dropout_ratio=params.dropout)#, is_training=is_training_placeholder)
        sent_intermediate_embed_tensor = sent_feature_embed(sent_placeholder, params.embedDim, dropout_ratio=params.dropout)#, is_training=is_training_placeholder)

        if params.attention:
            image_intermediate_embed_tensor, sent_intermediate_embed_tensor, _ = aligned_attention(image_intermediate_embed_tensor, sent_intermediate_embed_tensor, params.embedDim, skip=params.skip)#, is_training=is_training_placeholder)
        else:
            image_intermediate_embed_tensor = tf.nn.l2_normalize(image_intermediate_embed_tensor, 1, epsilon=1e-10)
            sent_intermediate_embed_tensor = tf.nn.l2_normalize(sent_intermediate_embed_tensor, 1, epsilon=1e-10)
        # shared layers

        if params.shared:
            image_embed_tensor = shared_embed(image_intermediate_embed_tensor, params.embedDim, dropout_ratio=params.dropout)
            sent_embed_tensor = shared_embed(sent_intermediate_embed_tensor, params.embedDim, reuse=True, dropout_ratio=params.dropout)
        else:
            image_embed_tensor = image_intermediate_embed_tensor
            sent_embed_tensor = sent_intermediate_embed_tensor

        # category loss
        class_loss = category_loss(image_embed_tensor, sent_embed_tensor, label_placeholder, params.numCategories)
        
        # metric loss
        metric_loss = triplet_loss(image_embed_tensor, sent_embed_tensor, label_placeholder, params.margin)
        # metric_loss, _ = batch_all_triplet_loss(tf.concat([label_placeholder, label_placeholder], axis = 0), tf.concat([image_embed_tensor, sent_embed_tensor], axis = 0), params.margin, squared=True)
        
        # modality loss
        # modal_loss = modality_loss(image_embed_tensor, sent_embed_tensor, lambda_placeholder, params.embedDim, params.batchSize)
        
        # total loss
        
        # total_loss = tf.reduce_mean(tf.reduce_sum(tf.square(image_embed_tensor-sent_embed_tensor),1))
        emb_loss = params.categoryScale*class_loss + params.metricScale*metric_loss
        # scaled_modal_loss = params.modalityScale*modal_loss

        total_loss = params.categoryScale*class_loss + params.metricScale*metric_loss #+ params.modalityScale*modal_loss
        
        # scopes for different functions to separate learning
        t_vars = tf.trainable_variables()
        # pdb.set_trace()
        visfeat_vars = [v for v in t_vars if 'vf_' in v.name] # only visual embedding layers
        sentfeat_vars = [v for v in t_vars if 'sf_' in v.name] # only sent embedding layers
        sharedfeat_vars = [v for v in t_vars if 'se_' in v.name] # shared embedding layers
        attention_vars = [v for v in t_vars if 'att' in v.name] # only attention weights
        
        catclas_vars = [v for v in t_vars if 'cc_' in v.name]
        # modclas_vars = [v for v in t_vars if 'mc_' in v.name]
        
        tf.summary.scalar('metric loss', metric_loss)
        # tf.summary.scalar('modality loss', modal_loss)
        tf.summary.scalar('category loss', class_loss)
        tf.summary.scalar('total loss', total_loss)
        
        global_step_tensor = tf.Variable(0, trainable=False)
        
        # learning_rate = tf.train.exponential_decay(initLR, global_step_tensor, steps_per_epoch, 0.9, staircase=True)
        # optimizer  = tf.train.AdamOptimizer(initLR)
        
        # gvs = optimizer.compute_gradients(total_loss)
        # capped_gvs = [(tf.clip_by_value(grad, -1.0, 1.0), var) for grad, var in gvs]
        # train_op = optimizer.apply_gradients(capped_gvs, global_step_tensor)

        # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # with tf.control_dependencies(update_ops):
            # train_op = optimizer.minimize(total_loss, global_step_tensor)
        
        # total_train_op = tf.train.AdamOptimizer(learning_rate=params.initLR).minimize(total_loss, global_step=global_step_tensor)
        
        emb_train_op = tf.train.AdamOptimizer(learning_rate=params.initLR).minimize(emb_loss, global_step=global_step_tensor, var_list=visfeat_vars+sentfeat_vars+sharedfeat_vars+catclas_vars+attention_vars)
        # modal_train_op = tf.train.AdamOptimizer(learning_rate=params.initLR).minimize(scaled_modal_loss, var_list=modclas_vars)

        saver = tf.train.Saver(max_to_keep=10)
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True
        summary_tensor = tf.summary.merge_all()
        
        with tf.Session(config=session_config) as sess:
            summary_writer = tf.summary.FileWriter(params.ExperimentDirectory, graph=tf.get_default_graph())
            sess.run([tf.global_variables_initializer()])
            if params.Restore:
                print('restoring checkpoint')
                saver.restore(sess, tf.train.latest_checkpoint(params.ExperimentDirectory))
            
            # i2s_mapk_all = []
            # s2i_mapk_all = []
            
            for i in range(num_steps):
            
                if i % steps_per_epoch == 0:
                    train_data_loader.shuffle_inds()
                
                im_feats, sent_feats, labels = train_data_loader.get_batch(i % steps_per_epoch, params.batchSize, image_aug = params.image_aug)
                
                feed_dict = {image_placeholder : im_feats, sent_placeholder : sent_feats, label_placeholder : labels}

                _, summary, global_step, loss_total, loss_class, loss_metric = sess.run(
                                                            [emb_train_op, summary_tensor, global_step_tensor, 
                                                            total_loss, class_loss, metric_loss], feed_dict = feed_dict)
                
                
                summary_writer.add_summary(summary, global_step)
                
                if i % params.printEvery == 0:
                    print('Epoch: %d | Step: %d | Total Loss: %f | Class Loss: %f | Metric Loss: %f' % (i // steps_per_epoch, i, loss_total, loss_class, loss_metric))
                
                if (i % (steps_per_epoch * params.saveEvery) == 0 and i > 0) or (i == num_steps-1):
                    print('Saving checkpoint at step %d' % i)
                    saver.save(sess, os.path.join(params.ExperimentDirectory, 'model.ckpt'+str(global_step)))
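
In this example, `triplet_loss(image_embed_tensor, sent_embed_tensor, label_placeholder, params.margin)` is a project-specific cross-modal loss whose body is not included. A hypothetical TF1-style sketch, assuming image/sentence pairs that share a label are positives and the hardest cross-label pair serves as the negative, might be:

import tensorflow as tf

def triplet_loss(im_embeds, sent_embeds, labels, margin):
    # Hypothetical sketch, not the project's actual implementation.
    sim = tf.matmul(im_embeds, sent_embeds, transpose_b=True)              # pairwise similarities
    same = tf.cast(tf.equal(tf.expand_dims(labels, 1),
                            tf.expand_dims(labels, 0)), tf.float32)        # 1 where labels match
    pos_sim = tf.reduce_sum(sim * same, axis=1) / tf.maximum(tf.reduce_sum(same, axis=1), 1.0)
    hardest_neg = tf.reduce_max(sim - 1e9 * same, axis=1)                  # best-scoring cross-label pair
    return tf.reduce_mean(tf.maximum(0.0, margin + hardest_neg - pos_sim))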
Example #3 (fragment: the beginning of this snippet is missing)

                             labels,
                             loss_type=config['training']['loss'],
                             margin=config['training']['margin'])

        is_pos = (labels == torch.ones(labels.shape).long().to(device)).float()
        is_neg = 1 - is_pos
        n_pos = torch.sum(is_pos)
        n_neg = torch.sum(is_neg)
        sim = compute_similarity(config, x, y)
        sim_pos = torch.sum(sim * is_pos) / (n_pos + 1e-8)
        sim_neg = torch.sum(sim * is_neg) / (n_neg + 1e-8)
    else:
        x_1, y, x_2, z = reshape_and_split_tensor(graph_vectors, 4)
        loss = triplet_loss(x_1,
                            y,
                            x_2,
                            z,
                            loss_type=config['training']['loss'],
                            margin=config['training']['margin'])

        sim_pos = torch.mean(compute_similarity(config, x_1, y))
        sim_neg = torch.mean(compute_similarity(config, x_2, z))

    graph_vec_scale = torch.mean(graph_vectors**2)
    if config['training']['graph_vec_regularizer_weight'] > 0:
        loss += (config['training']['graph_vec_regularizer_weight'] * 0.5 *
                 graph_vec_scale)

    optimizer.zero_grad()
    loss.backward(torch.ones_like(loss))  #
    nn.utils.clip_grad_value_(model.parameters(),
                              config['training']['clip_value'])
Example #4
def train(args):

    # Decode the tensors from tf record using tf.dataset API
    data = DataLoader(record_path=args.record_path, batch_size=args.batch_size, num_epochs=args.num_epochs)
 
    image, mask, background_image, object_image, label = data._read_mask_data()
    # Old preprocessing
    mask_not = tf.tile(tf.cast(tf.logical_not(tf.cast(mask, tf.bool)), tf.float32), [1,1,1,3])
    background_image_after = tf.multiply(image, mask_not)
    object_image_after = tf.multiply(image, mask)
    # Define the model
    model = DAML(args.base, margin=args.margin, embedding_dim=args.embedding_dim, is_training=True)

    # Build the model
    if args.model=="triplet":
        print "Training : {}".format(args.model)
        # Get the triplet embeddings
        anchor_embedding, positive_embedding, negative_embedding = model.build_triplet_model(anchor_image_placeholder, positive_image_placeholder, negative_image_placeholder)
        # L2 normalize the embeddings before loss
        anchor_embedding_l2 = tf.nn.l2_normalize(anchor_embedding, name='normalized_anchor')
        positive_embedding_l2 = tf.nn.l2_normalize(positive_embedding, name='normalized_positive')
        negative_embedding_l2 = tf.nn.l2_normalize(negative_embedding, name='normalized_negative')
        # compute the triplet loss
        total_loss, positive_distance, negative_distance = triplet_loss(anchor_embedding_l2, positive_embedding_l2, negative_embedding_l2)
        
        # Define the summaries
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.scalar('Positive-Anchor distance', positive_distance)
        tf.summary.scalar('Negative-Anchor distance', negative_distance)
        tf.summary.image('Anchor_image', anchor_image_placeholder)
        tf.summary.image('Positive_image', positive_image_placeholder)
        tf.summary.image('Negative_image', negative_image_placeholder)
        
    elif args.model=="mask-triplet":
        anchor_embedding = model.build_mask_triplet_model(image, background_image)
        # compute the triplet loss
        total_loss = model.triplet_loss(label, anchor_embedding)

        # Define the summaries
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.image('Mask', mask)
        tf.summary.image('Anchor Image', image)
        tf.summary.image('Object image', object_image)
        tf.summary.image('Background Image', background_image)
        
    elif args.model=="object_whole":
        whole_embedding, object_embedding = model.build_object_whole_triplet_model(image, object_image)
        anchor_embedding = whole_embedding + object_embedding
        # compute the triplet loss
        total_loss = model.triplet_loss(label, anchor_embedding)

        # Define the summaries
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.image('Mask', mask)
        tf.summary.image('Anchor Image', image)
        tf.summary.image('Object image', object_image)
        tf.summary.image('Background Image', background_image)
        
    elif args.model=="object_whole_separate":
        whole_embedding, object_embedding = model.build_object_whole_triplet_model(image, object_image_after)
        
        # compute the triplet loss
        whole_loss = model.triplet_loss(label, whole_embedding)
        object_loss = model.triplet_loss(label, object_embedding)
        total_loss = whole_loss + object_loss
        # Define the summaries
        # pdb.set_trace()
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.scalar('whole_loss', whole_loss)
        tf.summary.scalar('object_loss', object_loss)
        tf.summary.image('Mask', mask)
        tf.summary.image('Anchor Image', image)
        tf.summary.image('Object image', object_image_after)
        
        
    elif args.model=="triplet_single":
        print "Training : {}".format(args.model)
        # Get the anchor embeddings
        anchor_features = model.feature_extractor(object_image)
        anchor_embedding = model.build_embedding(anchor_features)

        # compute the lifted loss
        total_loss = model.triplet_loss(label, anchor_embedding)

        # Define the summaries
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.image('Anchor_image', image)
        tf.summary.image('Object image', object_image)
        tf.summary.image('Background_image', background_image)
        tf.summary.image('Mask', mask)
        
    elif args.model=="triplet_mask":
        print "Training : {}".format(args.model)
        # Get the anchor embeddings
        coord_conv_anchor = tf.placeholder(shape=[args.batch_size, 224, 224, 4], dtype=tf.float32, name='anchor_images')
        anchor_features = model.feature_extractor(coord_conv_anchor)
        anchor_embedding = model.build_embedding(anchor_features)

        # compute the lifted loss
        total_loss = model.triplet_loss(label_placeholder, anchor_embedding)

        # Define the summaries
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.image('Anchor_image', image)
        tf.summary.image('Anchor_image', mask)
        
    elif args.model=="lifted_single":
        print "Training : {}".format(args.model)
        # Get the anchor embeddings
        anchor_features = model.feature_extractor(anchor_image_placeholder)
        anchor_embedding = model.build_embedding(anchor_features)
        
        # compute the lifted loss
        total_loss = model.lifted_loss(label_placeholder, anchor_embedding)
        
        # Define the summaries
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.image('Anchor_image', anchor_image_placeholder)
        
    elif args.model=="lifted":
        print "Training : {}".format(args.model)
        anchor_embedding, positive_embedding, negative_embedding = model.build_triplet_model(anchor_image_placeholder, positive_image_placeholder, negative_image_placeholder)
        concat_embeddings = tf.concat([anchor_embedding, positive_embedding, negative_embedding], axis=0)
        positive_mask_placeholder = tf.placeholder(shape=[3*args.batch_size, 3*args.batch_size], dtype=tf.bool)
        # compute the lifted loss
        total_loss, _ = lifted_struct_loss(positive_mask_placeholder, concat_embeddings)
        
        # Define the summaries
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.image('Anchor_image', anchor_image_placeholder)
        tf.summary.image('Positive_image', positive_image_placeholder)
        tf.summary.image('Negative_image', negative_image_placeholder)
        
    elif args.model=="daml-lifted":
        print "Training : {}".format(args.model)
        anchor_embedding, positive_embedding, negative_embedding = model.build_triplet_model(anchor_image_placeholder, positive_image_placeholder, negative_image_placeholder)
        concat_embeddings = tf.concat([anchor_embedding, positive_embedding, negative_embedding], axis=0)
        positive_mask_placeholder = tf.placeholder(shape=[3*args.batch_size, 3*args.batch_size], dtype=tf.bool)
        # compute the lifted loss
        lifted_loss_t, _ = lifted_struct_loss(positive_mask_placeholder, concat_embeddings)
        
        # Get the synthetic embeddings
        synthetic_neg_embedding = model.generator(anchor_embedding, positive_embedding, negative_embedding)
        
        # L2 normalize the embeddings before loss
        anchor_embedding_l2 = tf.nn.l2_normalize(anchor_embedding, name='normalized_anchor')
        positive_embedding_l2 = tf.nn.l2_normalize(positive_embedding, name='normalized_positive')
        negative_embedding_l2 = tf.nn.l2_normalize(negative_embedding, name='normalized_negative')
        synthetic_neg_embedding_l2 = tf.nn.l2_normalize(synthetic_neg_embedding, name='normalized_synthetic')
        J_hard, J_reg, J_adv = model.daml_loss(anchor_embedding_l2, positive_embedding_l2, negative_embedding_l2, synthetic_neg_embedding_l2)
        J_gen = args.hard_weight*J_hard + args.reg_weight*J_reg + args.adv_weight*J_adv
        total_loss = args.metric_weight*lifted_loss_t + J_gen
        # Define the summaries
        tf.summary.scalar('J_m', lifted_loss_t)
        tf.summary.scalar('J_hard', J_hard)
        tf.summary.scalar('J_reg', J_reg)
        tf.summary.scalar('J_adv', J_adv)
        tf.summary.scalar('J_gen', J_gen)
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.image('Anchor_image', anchor_image_placeholder)
        tf.summary.image('Positive_image', positive_image_placeholder)
        tf.summary.image('Negative_image', negative_image_placeholder)
        
    elif args.model=="daml-triplet":
        # Get the triplet embeddings
        anchor_embedding, positive_embedding, negative_embedding = model.build_triplet_model(anchor_image_placeholder, positive_image_placeholder, negative_image_placeholder)
        positive_mask_placeholder = tf.placeholder(shape=[3*args.batch_size, 3*args.batch_size], dtype=tf.bool)
        # Get the synthetic embeddings
        synthetic_neg_embedding = model.generator(anchor_embedding, positive_embedding, negative_embedding)
        # L2 normalize the embeddings before loss
        anchor_embedding_l2 = tf.nn.l2_normalize(anchor_embedding, name='normalized_anchor')
        positive_embedding_l2 = tf.nn.l2_normalize(positive_embedding, name='normalized_positive')
        negative_embedding_l2 = tf.nn.l2_normalize(negative_embedding, name='normalized_negative')
        synthetic_neg_embedding_l2 = tf.nn.l2_normalize(synthetic_neg_embedding, name='normalized_synthetic')
        # Calculate Triplet loss
        triplet_loss_t, positive_distance, negative_distance = triplet_loss(anchor_embedding_l2, positive_embedding_l2, synthetic_neg_embedding_l2)
        J_hard, J_reg, J_adv = model.daml_loss(anchor_embedding_l2, positive_embedding_l2, negative_embedding_l2, synthetic_neg_embedding_l2)
        J_gen = args.hard_weight*J_hard + args.reg_weight*J_reg + args.adv_weight*J_adv
        total_loss = args.metric_weight*triplet_loss_t + J_gen
        # Define the summaries
        tf.summary.scalar('J_m', triplet_loss_t)
        tf.summary.scalar('J_hard', J_hard)
        tf.summary.scalar('J_reg', J_reg)
        tf.summary.scalar('J_adv', J_adv)
        tf.summary.scalar('J_gen', J_gen)
        tf.summary.scalar('Total Loss', total_loss)
        tf.summary.scalar('Positive-Anchor distance', positive_distance)
        tf.summary.scalar('Negative-Anchor distance', negative_distance)
        tf.summary.image('Anchor_image', anchor_image_placeholder)
        tf.summary.image('Positive_image', positive_image_placeholder)
        tf.summary.image('Negative_image', negative_image_placeholder)


    # Get the training op for the whole network.
    train_op, initial_saver, global_step = get_training_op(total_loss, args)

    #Merge summaries
    summary_tensor = tf.summary.merge_all()

    now = datetime.datetime.now()
    summary_dir_name = args.exp_path+'/s_'+args.model+'_'+args.mode+'_'+now.strftime("%Y-%m-%d_%H_%M")
    checkpoint_dir_name = args.exp_path+'/c_'+args.model+'_'+args.mode+'_'+now.strftime("%Y-%m-%d_%H_%M")
    if args.mode=='only_gen':
        summary_dir_name = args.exp_path+'/gen_summaries_'+args.model+'_'+args.mode+'_'+now.strftime("%Y-%m-%d_%H_%M")
        checkpoint_dir_name = args.exp_path+'/gen_checkpoints_'+args.model+'_'+args.mode+'_'+now.strftime("%Y-%m-%d_%H_%M")
    summary_filewriter = tf.summary.FileWriter(summary_dir_name, tf.get_default_graph())

    # Checkpoint saver to save the variables of the entire graph. Training monitored session handles queue runners internally.
    checkpoint_saver = tf.train.Saver(keep_checkpoint_every_n_hours=1.0)
    checkpoint_saver_hook = tf.train.CheckpointSaverHook(saver=checkpoint_saver, checkpoint_dir=checkpoint_dir_name, save_steps=args.save_steps)
    with tf.train.MonitoredTrainingSession(hooks=[checkpoint_saver_hook]) as sess:
        #Restore the feature_extractor checkpoint
        initial_saver.restore(sess, args.checkpoint)
        print "Restored: {}".format(args.checkpoint)
        while not sess.should_stop():
            try:
                start_time = time.time()
                # Get a batch of input pairs which are positive
                # image_np, mask_np, label_np = sess.run([image, mask, label])										
                # top_image_np, bottom_image_np, label_np, pos_flag_np = sess.run([top_image, bottom_image, label, pos_flag])		

                # Create positive and negative pairing 
                # anchor_image_b, positive_image_b, negative_image_b,  \
                            # pos_labels_b, neg_labels_b, pos_flag_b, neg_flag_b, adjacency, positive_mask = permutate(top_image_np, bottom_image_np, label_np, pos_flag_np)
                # Run the training op
                # _, global_step_value, total_loss_value, summary_value =  sess.run([train_op, global_step, total_loss, summary_tensor], 
                                                                # feed_dict={anchor_image_placeholder: anchor_image_b,
                                                                           # positive_image_placeholder: positive_image_b,
                                                                           # negative_image_placeholder: negative_image_b,
                                                                           # positive_mask_placeholder: adjacency
                                                                           # })
                                                                           
                # Run the training op
                _, global_step_value, total_loss_value, summary_value =  sess.run([train_op, global_step, total_loss, summary_tensor])
                                                                           
                # post_mask = process_mask(mask_np)
                # coord_conv_batch = np.concatenate([image_np, post_mask], axis=3) #, row_vec_batch, col_vec_batch
                # Run the training op
                # _, global_step_value, total_loss_value, summary_value =  sess.run([train_op, global_step, total_loss, summary_tensor], 
                                                                # feed_dict={coord_conv_anchor: coord_conv_batch,
                                                                           # label_placeholder: label_np})
                if (global_step_value+1)%100 == 0:
                    iter_time = time.time() - start_time
                    print('Iteration: {} Loss: {} Step time: {}'.format(global_step_value+1, total_loss_value, iter_time))
                    summary_filewriter.add_summary(summary_value, global_step_value)
                
            except tf.errors.OutOfRangeError:
                break
                
        print "Training completed"
Example #5
File: main.py Project: gxdai/FCS_pytorch
def train(args):
    # basic arguments.
    ngpu = args.ngpu
    margin = args.margin
    manual_seed = args.manual_seed
    torch.manual_seed(manual_seed)
    mean_value = args.mean_value
    std_value = args.std_value
    print("margin = {:5.2f}".format(margin))
    print("manual_seed = {:5.2f}".format(manual_seed))
    print("mean_value = {:5.2f}".format(mean_value))
    print("std_value = {:5.2f}".format(std_value))
    num_epochs = args.num_epochs
    train_batch_size = args.train_batch_size
    test_batch_size = args.test_batch_size
    gamma = args.gamma # for learning rate decay
    learning_rate = args.learning_rate
    learning_rate2 = args.learning_rate2


    loss_type = args.loss_type
    dataset_name = args.dataset_name
    pair_type = args.pair_type
    mode = args.mode
    weight_file = args.weight_file
    print("pair_type = {}".format(pair_type))
    print("loss_type = {}".format(loss_type))
    print("mode = {}".format(mode))
    print("weight_file = {}".format(weight_file))

    root_dir = args.root_dir
    image_txt = args.image_txt
    train_test_split_txt = args.train_test_split_txt
    label_txt = args.label_txt
    ckpt_dir = args.ckpt_dir
    eval_step = args.eval_step
    display_step = args.display_step
    embedding_size = args.embedding_size


    pretrained = args.pretrained
    aux_logits = args.aux_logits
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    kargs = {'ngpu': ngpu, 'pretrained': pretrained, 'aux_logits':aux_logits, 'embedding_size': embedding_size}

    # create directory
    model_dir = os.path.join(ckpt_dir, dataset_name, loss_type, str(int(embedding_size)))
    print("model_dir = {}".format(model_dir))
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    # network and loss
    siamese_network = SiameseNetwork(**kargs)


    first_group, second_group = siamese_network.separate_parameter_group()

    param_lr_dict = [
               {'params': first_group, 'lr': learning_rate2},
               {'params': second_group, 'lr': learning_rate}
              ]

    gpu_number = torch.cuda.device_count()
    if device.type == 'cuda' and gpu_number > 1:
        siamese_network = nn.DataParallel(siamese_network, list(range(torch.cuda.device_count())))
    siamese_network.to(device)

    # contrastive_loss = ContrastiveLoss(margin=margin)

    # params = siamese_network.parameters()

    print("args.optimizer = {:10s}".format(args.optimizer))
    print("learning_rate = {:5.5f}".format(learning_rate))
    print("learning_rate2 = {:5.5f}".format(learning_rate2))
    optimizer = configure_optimizer(param_lr_dict, optimizer=args.optimizer)

    # using different lr
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma, last_epoch=-1)


    transform = transforms.Compose([transforms.Resize((299, 299)),
                                    transforms.CenterCrop(299),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
                                  )

    if dataset_name == 'cub200':
        """
        print("dataset_name = {:10s}".format(dataset_name))
        print(root_dir)
        print(image_txt)
        print(train_test_split_txt)
        print(label_txt)
        """
        dataset_train = CubDataset(root_dir, image_txt, train_test_split_txt, label_txt, transform=transform, is_train=True, offset=1)
        dataset_eval = CubDataset(root_dir, image_txt, train_test_split_txt, label_txt, transform=transform, is_train=False, offset=1)
    elif dataset_name == 'online_product':
        """
        print("dataset_name = {:10s}".format(dataset_name))
        """
        dataset_train = OnlineProductDataset(root_dir, train_txt=image_txt, test_txt=train_test_split_txt, transform=transform, is_train=True, offset=1)
        dataset_eval = OnlineProductDataset(root_dir, train_txt=image_txt, test_txt=train_test_split_txt, transform=transform, is_train=False, offset=1)
    elif dataset_name == "car196":
        print("dataset_name = {}".format(dataset_name))
        dataset_train = CarDataset(root_dir, image_info_mat=image_txt, transform=transform, is_train=True, offset=1)
        dataset_eval = CarDataset(root_dir, image_info_mat=image_txt, transform=transform, is_train=False, offset=1)


    dataloader = DataLoader(dataset=dataset_train, batch_size=train_batch_size, shuffle=False, num_workers=4)
    dataloader_eval = DataLoader(dataset=dataset_eval, batch_size=test_batch_size, shuffle=False, num_workers=4)

    log_for_loss = []

    if mode == 'evaluation':
        print("Do one time evluation and exit")
        print("Load pretrained model")
        siamese_network.module.load_state_dict(torch.load(weight_file))
        print("Finish loading")
        print("Calculting features")
        feature_set, label_set, path_set = get_feature_and_label(siamese_network, dataloader_eval, device)
        rec_pre = evaluation(feature_set, label_set)
        # np.save("car196_rec_pre_ftl.npy", rec_pre)
        # for visualization
        sum_dict = {'feature': feature_set, 'label': label_set, 'path': path_set}
        np.save('car196_fea_label_path.npy', sum_dict)
        sys.exit()
    print("Finish eval")

    for epoch in range(num_epochs):
        if epoch == 0:
            feature_set, label_set, _ = get_feature_and_label(siamese_network, dataloader_eval, device)
            # distance_type: Euclidean or cosine
            rec_pre = evaluation(feature_set, label_set, distance_type='cosine')
        siamese_network.train()
        for i, data in enumerate(dataloader, 0):
            # img_1, img_2, sim_label = data['img_1'].to(device), data['img_2'].to(device), data['sim_label'].type(torch.FloatTensor).to(device)
            img_1, img_2, label_1, label_2 = data['img_1'].to(device), data['img_2'].to(device), data['label_1'].to(device), data['label_2'].to(device)
            optimizer.zero_grad()
            output_1, output_2 = siamese_network(img_1, img_2)
            pair_dist, pair_sim_label = calculate_distance_and_similariy_label(output_1, output_2, label_1, label_2, sqrt=True, pair_type=pair_type)
            if loss_type == "contrastive_loss":
                loss, positive_loss, negative_loss = contrastive_loss(pair_dist, pair_sim_label, margin)
            elif loss_type == "focal_contrastive_loss":
                loss, positive_loss, negative_loss = focal_contrastive_loss(pair_dist, pair_sim_label, margin, mean_value, std_value)
            elif loss_type == "triplet_loss":
                loss, positive_loss, negative_loss = triplet_loss(pair_dist, pair_sim_label, margin)
            elif loss_type == "focal_triplet_loss":
                loss, positive_loss, negative_loss = focal_triplet_loss(pair_dist, pair_sim_label, margin, mean_value, std_value)
            elif loss_type == "angular_loss":
                center_output = (output_1 + output_2)/2.
                pair_dist_2, _ = calculate_distance_and_similariy_label(center_output, output_2, label_1, label_2, sqrt=True, pair_type=pair_type)
                # angle margin is 45^o
                loss, positive_loss, negative_loss = angular_loss(pair_dist, pair_dist_2, pair_sim_label, 45)
            else:
                print("Unknown loss function")
                sys.exit()

            # try my own customized loss function
            # loss = contrastive_loss(output_1, output_2, pair_sim_label)
            loss.backward()
            optimizer.step()
            log_for_loss.append(loss.detach().item())
            if i % display_step == 0 and i > 0:
                print("{}, Epoch [{:3d}/{:3d}], Iter [{:3d}/{:3d}], Loss: {:6.5f}, Positive loss: {:6.5f}, Negative loss: {:6.5f}".format(
                      datetime.datetime.now(), epoch, num_epochs, i, len(dataloader), loss.item(), positive_loss.item(), negative_loss.item()))
        if epoch % eval_step == 0:
            print("Start evalution")
            # np.save(loss_type +'.npy', log_for_loss)
            feature_set, label_set, _ = get_feature_and_label(siamese_network, dataloader_eval, device)
            # distance_type: Euclidean or cosine
            rec_pre = evaluation(feature_set, label_set, distance_type='cosine')
            torch.save(siamese_network.module.state_dict(), os.path.join(model_dir, 'model_' + str(epoch) +'_.pth'))
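
The loss functions dispatched above (`contrastive_loss`, `focal_contrastive_loss`, `triplet_loss`, `focal_triplet_loss`, `angular_loss`) all share the `(pair_dist, pair_sim_label, margin, ...)` interface and are defined elsewhere in the repository. As an assumption about what the `contrastive_loss` branch computes, a minimal sketch of the classic contrastive formulation behind that interface might be:

import torch

def contrastive_loss(pair_dist, pair_sim_label, margin):
    # Hypothetical sketch of the (pair_dist, pair_sim_label, margin) interface:
    # similar pairs (label 1) are pulled together, dissimilar pairs (label 0)
    # are pushed apart until they are at least `margin` away.
    positive_loss = torch.mean(pair_sim_label * pair_dist.pow(2))
    negative_loss = torch.mean((1 - pair_sim_label) * torch.clamp(margin - pair_dist, min=0.0).pow(2))
    return positive_loss + negative_loss, positive_loss, negative_loss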