for param in model.frontend.parameters():
        param.requires_grad = False
    for param in model.resnet.parameters():
        param.requires_grad = False
else:
    print_log('\n\nno freezing', log=args.logfile)

print_log('loading data', log=args.logfile)
# Build runners only for the phases that were requested on the command line.
if args.train:
    trainer = Trainer(args)
if args.val:
    validator = Validator(args)

if args.train:
    # Per-epoch metric histories; every entry is a one-element list, which is
    # the shape handed to plot_loss / plot_accu below.
    loss_history_train, loss_history_val = [], []
    accu_history_train, accu_history_val = [], []
    for epoch in range(args.start_epoch, args.end_epoch):
        loss_train, accu_train = trainer.epoch(model, epoch)
        loss_history_train.append([loss_train])
        accu_history_train.append([accu_train])

        # Validation metrics exist only when a validator was built. The
        # previous code appended loss_val / accu_val unconditionally and
        # raised NameError whenever args.val was False.
        if args.val:
            loss_val, accu_val = validator.epoch(model, epoch)
            loss_history_val.append([loss_val])
            accu_history_val.append([accu_val])

        # plot figure
        # NOTE(review): with args.val False the validation histories stay
        # empty — confirm plot_loss / plot_accu tolerate an empty list.
        plot_loss(loss_history_train, loss_history_val, save_dir=args.save_dir)
        plot_accu(accu_history_train, accu_history_val, save_dir=args.save_dir)

if args.test:
    tester = Tester(args)
    tester.epoch(model)
    # Additionally run one validation pass (epoch=0) when a validator exists.
    if args.val:
        validator.epoch(model, epoch=0)
Exemple #2
0
        model = LipRead(options)
        load_model(model, state_dict, grad_states) # load weights and freeze states
        last_epoch = states["epoch"]

    
    # save current options
    with open(os.path.join(result_dir, "options_used.toml"), 'w') as f:
        toml.dump(options, f)


    if(options["general"]["usecudnnbenchmark"] and options["general"]["usecudnn"]):
        print("Running cudnn benchmark...")
        torch.backends.cudnn.benchmark = True


    if options["training"]["train"]:
        trainer = Trainer(options)
    if(options["validation"]["validate"]):
        validator = Validator(options, result_dir)


    final_epoch = args.final_epoch if args.final_epoch is not None else options["training"]["epochs"]
    for epoch in range(last_epoch + 1, final_epoch):
        loss = trainer.epoch(model, epoch) if options["training"]["train"] else ''
        accuracy = validator.epoch(model, epoch) if options["validation"]["validate"] else ''
        csv.add(epoch, accuracy=accuracy, loss=loss)
        save_checkpoint(result_dir, epoch, model, options=options)

    # save the final model 
    torch.save(model.state_dict(), os.path.join(result_dir, "epoch{}.pt".format(epoch)))
print_log('creating the model', log=options["general"]["logfile"])
model = LipRead(options)

print_log('loading model', log=options["general"]["logfile"])
if options["general"]["loadpretrainedmodel"]:
    print_log('loading the pretrained model at %s' %
              options["general"]["pretrainedmodelpath"],
              log=options["general"]["logfile"])
    # The model is still on the CPU here (the optional .cuda() move happens
    # below), so map the checkpoint tensors to the CPU as well. Without
    # map_location, a checkpoint saved from GPU tensors cannot be loaded on
    # a machine without CUDA.
    model.load_state_dict(torch.load(
        options["general"]["pretrainedmodelpath"], map_location="cpu"))
if options["general"]["usecudnn"]:
    # Move the model to the configured GPU.
    model = model.cuda(options["general"]["gpuid"])

print_log('loading data', log=options["general"]["logfile"])
if options["training"]["train"]:
    trainer = Trainer(options)
if options["validation"]["validate"]:
    validator = Validator(options)
    # One validation pass before any training, run with epoch=0.
    validator.epoch(model, epoch=0)

if options["training"]["train"]:
    for epoch in range(options["training"]["startepoch"],
                       options["training"]["endepoch"]):
        trainer.epoch(model, epoch)
        if options["validation"]["validate"]:
            validator.epoch(model, epoch)
    # TODO: testing is disabled; to enable it, honor
    # options["testing"]["test"] by building Tester(options) and calling
    # tester.epoch(model).

# Close the log handle that every print_log call above wrote to.
options["general"]["logfile"].close()
Exemple #4
0
     training = Trainer(
         cuda=cuda,
         cfg=cfg,
         model_depth=model_depth,
         model_rgb=model_rgb,
         model_fusion=model_fusion,
         train_loader=train_loader,
         test_data_list = ["DUT-RGBD","NJUD","NLPR","SSD","STEREO","LFSD","RGBD135","SIP","ReDWeb"],
         test_data_root = args.test_dataroot,
         salmap_root = args.salmap_root,
         outpath=args.snapshot_root,
         logging=logging,
         writer=writer,
         max_epoch=args.epoch,
     )
     training.epoch = start_epoch
     training.iteration = start_iteration
     training.train()
 else:
     # -------------------------- inference --------------------------- #
     res = []
     for id, (data, depth, bins, img_name, img_size) in enumerate(test_loader):
         # print('testing bach %d' % id)
         inputs = Variable(data).cuda()
         depth = Variable(depth).cuda()
         bins = Variable(bins).cuda()
         n, c, h, w = inputs.size()
         depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1)
         torch.cuda.synchronize()
         start = time.time()
         with torch.no_grad():