Esempio n. 1
0
def main():
    """Build the test data loaders and a VAMetric2 model, then run ``test``.

    All settings are read from the module-level ``opt`` object:
    ``batchSize``, ``workers``, ``init_model`` and ``cuda``.
    """
    global opt

    # Both loaders use identical settings; shuffling is disabled for each.
    loader_kwargs = dict(batch_size=opt.batchSize,
                         shuffle=False,
                         num_workers=int(opt.workers))
    test_video_loader = torch.utils.data.DataLoader(test_video_dataset,
                                                    **loader_kwargs)
    test_audio_loader = torch.utils.data.DataLoader(test_audio_dataset,
                                                    **loader_kwargs)

    net = models.VAMetric2()

    # An empty ``init_model`` means "no checkpoint": the freshly constructed
    # model is used as-is.
    weights_path = opt.init_model
    if weights_path != '':
        print('loading pretrained model from {0}'.format(weights_path))
        net.load_state_dict(torch.load(weights_path))

    if opt.cuda:
        print('shift model to GPU .. ')
        net = net.cuda()

    test(test_video_loader, test_audio_loader, net, opt)
Esempio n. 2
0
def main():
    """Build the test data loaders, load the selected model and run ``test``.

    Settings are read from the module-level ``opt`` object: ``batchSize``,
    ``workers``, ``model``, ``init_model`` and ``cuda``.

    ``opt.model`` selects the architecture ('VAMetric' or 'VAMetric2'); any
    other value falls back to ``VAMetric`` and ``opt.model`` is rewritten to
    'VAMetric' to reflect the choice actually made.

    Raises:
        IOError: if ``opt.init_model`` is empty, i.e. no pretrained weights
            were configured.
    """
    global opt
    # test data loader
    test_video_loader = torch.utils.data.DataLoader(test_video_dataset,
                                                    batch_size=opt.batchSize,
                                                    shuffle=False,
                                                    num_workers=int(opt.workers))
    test_audio_loader = torch.utils.data.DataLoader(test_audio_dataset,
                                                    batch_size=opt.batchSize,
                                                    shuffle=False,
                                                    num_workers=int(opt.workers))

    # create model
    # Model names are compared by value (==). An identity check (``is``)
    # against a string literal is not guaranteed to succeed for an equal
    # string built at run time, which would silently select the fallback.
    if opt.model == 'VAMetric':
        model = models.VAMetric()
    elif opt.model == 'VAMetric2':
        model = models.VAMetric2()
    else:
        # Unrecognised name: fall back to VAMetric and record that choice.
        model = models.VAMetric()
        opt.model = 'VAMetric'

    if opt.init_model != '':
        print('loading pretrained model from {0}'.format(opt.init_model))
        model.load_state_dict(torch.load(opt.init_model))
    else:
        raise IOError('Please add your pretrained model path to init_model in config file!')

    if opt.cuda:
        print('shift model to GPU .. ')
        model = model.cuda()

    test(test_video_loader, test_audio_loader, model, opt)
Esempio n. 3
0
def main():
    """Train the selected model and periodically save checkpoints.

    Settings are read from the module-level ``opt`` object: ``batchSize``,
    ``workers``, ``model``, ``init_model``, ``cuda``, ``max_epochs``,
    ``epoch_save`` and ``checkpoint_folder``.

    ``opt.model`` selects the architecture ('VAMetric', 'VAMetric2' or
    'VAMetric3'); any other value falls back to ``VA_Linear`` and
    ``opt.model`` is rewritten to 'VA_Linear', which is also the name used
    in the checkpoint file names.
    """
    global opt
    # train data loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batchSize,
                                               shuffle=True,
                                               num_workers=int(opt.workers))

    # create model
    # Model names are compared by value (==). An identity check (``is``)
    # against a string literal is not guaranteed to succeed for an equal
    # string built at run time, which would silently select the fallback.
    if opt.model == 'VAMetric':
        model = models.VAMetric()
    elif opt.model == 'VAMetric2':
        model = models.VAMetric2()
    elif opt.model == 'VAMetric3':
        model = models.VAMetric3()
    else:
        # Unrecognised name: fall back to VA_Linear and record that choice.
        model = models.VA_Linear()
        opt.model = 'VA_Linear'

    # An empty ``init_model`` means training starts from the freshly
    # constructed model.
    if opt.init_model != '':
        print('loading pretrained model from {0}'.format(opt.init_model))
        model.load_state_dict(torch.load(opt.init_model))

    # Contrastive Loss
    criterion = models.ContrastiveLoss()

    if opt.cuda:
        print('shift model and criterion to GPU .. ')
        model = model.cuda()
        criterion = criterion.cuda()

    # optimizer: Adam with its default hyper-parameters (no learning rate
    # from ``opt`` is passed here and no schedule is applied).
    optimizer = optim.Adam(model.parameters())

    for epoch in range(opt.max_epochs):
        #################################
        # train for one epoch
        #################################
        train(train_loader, model, criterion, optimizer, epoch, opt)

        ##################################
        # save checkpoints
        ##################################

        # save the model every ``opt.epoch_save`` epochs; the file name
        # carries the 1-based epoch number
        if ((epoch + 1) % opt.epoch_save) == 0:
            path_checkpoint = '{0}/{1}_state_epoch{2}.pth'.format(
                opt.checkpoint_folder, opt.model, epoch + 1)
            utils.save_checkpoint(model, path_checkpoint)
Esempio n. 4
0
def main():
    """Train a VAMetric2 model with SGD and periodically save checkpoints.

    Settings are read from the module-level ``opt`` object: ``batchSize``,
    ``workers``, ``train``, ``init_model``, ``cuda``, ``lr``, ``momentum``,
    ``weight_decay``, ``lr_decay``, ``lr_decay_epoch``, ``max_epochs``,
    ``epoch_save``, ``checkpoint_folder`` and ``prefix``.
    """
    global opt
    # only used when we resume training from some checkpoint model
    resume_epoch = 0
    # train data loader
    # for loader, droplast by default is set to false
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batchSize,
                                               shuffle=True,
                                               num_workers=int(opt.workers))

    # create model
    model = models.VAMetric2()

    # Pretrained weights are loaded only when ``opt.train`` is falsy and a
    # checkpoint path has been configured.
    if not opt.train and opt.init_model != '':
        print('loading pretrained model from {0}'.format(opt.init_model))
        model.load_state_dict(torch.load(opt.init_model))

    # Contrastive Loss
    criterion = models.ContrastiveLoss()

    if opt.cuda:
        print('shift model and criterion to GPU .. ')
        model = model.cuda()
        criterion = criterion.cuda()

    # optimizer
    optimizer = optim.SGD(model.parameters(),
                          opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # adjust learning rate every lr_decay_epoch
    def lambda_lr(epoch):
        """Return ``opt.lr_decay ** ((epoch + 1) // opt.lr_decay_epoch)``."""
        return opt.lr_decay**((epoch + 1) // opt.lr_decay_epoch)

    scheduler = LR_Policy(optimizer, lambda_lr)

    for epoch in range(resume_epoch, opt.max_epochs):
        #################################
        # train for one epoch
        #################################
        train(train_loader, model, criterion, optimizer, epoch, opt)
        scheduler.step()

        ##################################
        # save checkpoints
        ##################################

        # save the model's state dict every ``opt.epoch_save`` epochs; the
        # file name carries the 1-based epoch number
        if ((epoch + 1) % opt.epoch_save) == 0:
            path_checkpoint = '{0}/{1}_state_epoch{2}.pth'.format(
                opt.checkpoint_folder, opt.prefix, epoch + 1)
            utils.save_checkpoint(model.state_dict(), path_checkpoint)