コード例 #1
0
class Triplet(nn.Module):
    """Triplet network built on a single shared ``Siamese`` module.

    The same ``Siamese`` instance scores (query, near) and (query, far). The
    two scores are concatenated and fed through a fixed ``Linear(2, 1)`` with
    weights ``[50, -50]`` and zero bias, then a sigmoid. The per-sample output
    is therefore ``sigmoid(50 * near_score - 50 * far_score)``.

    Args:
        dropout: forwarded unchanged to ``Siamese``.
        normalization: forwarded unchanged to ``Siamese``.
    """

    def __init__(self, dropout=True, normalization=True):
        super(Triplet, self).__init__()
        # Built before the comparator, as in the original, so RNG consumption
        # during parameter initialisation happens in the same order.
        self.siamese = Siamese(dropout=dropout, normalization=normalization)
        self.final_layer = self._build_final_layer()

    @staticmethod
    def _build_final_layer(scale=50.0):
        """Return a frozen ``Linear(2, 1) -> Sigmoid`` score comparator.

        Args:
            scale: magnitude of the fixed weights ``[scale, -scale]``.

        Returns:
            ``nn.Sequential(linear, nn.Sigmoid())`` whose linear parameters
            do not require gradients.
        """
        linear_layer = nn.Linear(2, 1)
        # requires_grad must be given to nn.Parameter itself. Parameter()
        # defaults to requires_grad=True and ignores the flag on the tensor it
        # wraps, so setting it on the tensor beforehand has no effect.
        linear_layer.weight = nn.Parameter(torch.Tensor([[scale, -scale]]),
                                           requires_grad=False)
        # Bias kept at shape (1, 1), as originally created, so previously
        # saved state_dicts (key 'final_layer.0.bias') still load. It
        # broadcasts over the (batch, 1) linear output.
        linear_layer.bias = nn.Parameter(torch.zeros(1, 1), requires_grad=False)
        return nn.Sequential(linear_layer, nn.Sigmoid())

    def forward(self, query, near, far):
        """Score how much closer ``near`` is to ``query`` than ``far`` is.

        Args:
            query, near, far: inputs passed straight to ``self.siamese``.

        Returns:
            A flat tensor with one sigmoid score per sample in the batch.
        """
        # TODO: we can optimize this by only calculating the left/imitation branch once
        near_output = self.siamese(query, near)
        far_output = self.siamese(query, far)

        # Flatten each branch to (batch, -1). The concatenation must have 2
        # columns for Linear(2, 1), i.e. presumably one score per sample per
        # branch -- TODO confirm against Siamese's output shape.
        near_reshaped = near_output.view(len(near_output), -1)
        far_reshaped = far_output.view(len(far_output), -1)
        concatenated = torch.cat((near_reshaped, far_reshaped), dim=1)

        output = self.final_layer(concatenated)
        return output.view(-1)

    def load_siamese(self, model: nn.Module):
        """Copy the weights of ``model`` into the shared Siamese branch."""
        self.siamese.load_state_dict(model.state_dict())
コード例 #2
0
                                          **kwargs)

# Training entry script: fix the RNG seeds, configure logging under
# args.model_dir, optionally restore a Siamese snapshot, then train on CPU.
manualSeed = 9302  # fixed for reproducibility (was random.randint(1, 10000))
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

g_config = get_config()

model_dir = args.model_dir
setupLogger(os.path.join(model_dir, 'log.txt'))
g_config.model_dir = model_dir

criterion = nn.HingeEmbeddingLoss()
model = Siamese()

# Load a model snapshot if one was requested. Use a truthiness test rather
# than `is not ''`: identity comparison with a string literal is unreliable
# (and a SyntaxWarning), and it let None through to torch.load.
load_path = args.load_path
if load_path:
    # Training below runs with use_cuda=False, so map every storage to CPU
    # in case the snapshot was saved from a GPU.
    snapshot = torch.load(load_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(snapshot['state_dict'])
    logging('Model loaded from {}'.format(load_path))

train_model(model,
            criterion,
            train_loader,
            test_loader,
            g_config,
            use_cuda=False)
コード例 #3
0
from models.siamese import Siamese

# Inspection script: load a trained Siamese snapshot, log its architecture
# and parameter counts, print every parameter tensor, then exit.
import sys

load_from_torch7 = False  # not read anywhere in this script

print('Loading model...')
model_dir = 'models/snapshot/'
model_load_path = os.path.join(model_dir, 'snapshot_epoch_1.pt')
gConfig = get_config()
gConfig.model_dir = model_dir

criterion = nn.HingeEmbeddingLoss()
model = Siamese()

# Map storages to CPU so a snapshot saved from a GPU can be inspected on a
# CPU-only machine.
package = torch.load(model_load_path, map_location=lambda storage, loc: storage)

model.load_state_dict(package['state_dict'])
model.eval()
print('Model loaded from {}'.format(model_load_path))

logging('Model configuration:\n{}'.format(model))

# Distinct result names so the modelSize() helper is not shadowed by its
# own return value.
model_size, n_params_each_layer = modelSize(model)
logging('Model size: {}\n{}'.format(model_size, n_params_each_layer))

# print() call form: the Python-2 `print x` statement is a SyntaxError in
# Python 3 and breaks the whole file.
for a_param in model.parameters():
    print(a_param)

# sys.exit works even when the `site`-provided exit() builtin is unavailable.
sys.exit(0)
コード例 #4
0
    print('====> Test set loss: {:.4f}'.format(test_loss))
    writer.add_scalar('data/test_loss', test_loss, epoch)

def siamese_checkpoint_dir(lr, decay):
    """Return the checkpoint directory for a learning rate / decay pair.

    The directory is named after ``lr * 1000`` and ``decay * 100``, rounded to
    the nearest integer. Plain ``int()`` truncation mislabels values hit by
    float error (``0.29 * 100 == 28.999999999999996`` would become 28).
    """
    return 'checkpoints/siamese/lr_{}_decay_{}'.format(int(round(lr * 1000)),
                                                       int(round(decay * 100)))


def save(epoch):
    """Checkpoint the global model and optimizer after ``epoch``.

    The stored ``'epoch'`` is ``epoch + 1``, the epoch a resumed run should
    start from. Reads the module-level ``model``, ``optimizer``, ``lr`` and
    ``decay``.
    """
    state = {'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'epoch': epoch + 1}
    save_dir = siamese_checkpoint_dir(lr, decay)
    # makedirs also creates the missing 'checkpoints/siamese' parents, on
    # which os.mkdir would raise.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    torch.save(state, os.path.join(save_dir, 'ae_{}.pth.tar'.format(epoch)))

# Optionally resume from a saved checkpoint: model weights, optimizer state
# and the epoch to continue from.
start_epoch = 1
resume = True
if resume:
    # Map storages to CPU regardless of the device the checkpoint came from.
    state = torch.load('./checkpoints/best_siamese.pth.tar', map_location=lambda storage, loc: storage)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    start_epoch = state['epoch']

print('Begin train...')
for epoch in range(start_epoch, args.epochs + 1):
    train(epoch)
    test(epoch)

    # Adjust learning rate *before* checkpointing. The saved optimizer state
    # then carries the rate for the next epoch, which is the epoch save()
    # records as the resume point. Saving first made a resumed run repeat
    # the previous epoch's learning rate, skipping one decay step.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * decay

    save(epoch)