Esempio n. 1
0
def main(args):
    """Train the four-network pipeline on the Pokemon data loaders.

    Builds the networks returned by ``custom_model.buildModel``,
    ``buildModelStage2``, ``buildModelStage3`` and ``buildModelStage0``,
    configures one optimizer/scheduler pair over the stage-2 and stage-3
    networks only, and runs the epoch loop from ``args.start_epoch`` to
    ``args.epochs`` inclusive. Every ``args.save_intv`` epochs all four
    networks are passed to ``model_utils.saveMultiCheckpoint``; every
    ``args.val_intv`` epochs a validation pass is run.

    Args:
        args: option namespace forwarded to every helper; read directly here
            for ``log_dir``, ``start_epoch``, ``epochs``, ``save_intv``,
            ``val_intv`` and ``cp_dir``.

    NOTE(review): ``log`` is not defined in this function and is presumably a
    module-level logger created elsewhere in the file — confirm.
    """
    model_s1 = custom_model.buildModel(args)
    model_s2 = custom_model.buildModelStage2(args)
    model_s3 = custom_model.buildModelStage3(args)
    model_s0 = custom_model.buildModelStage0(args)
    # Order matters: train/test/checkpoint helpers receive this list as-is.
    models = [model_s1, model_s2, model_s3, model_s0]

    # Only the stage-2 and stage-3 networks are handed to the optimizer.
    optimizer, scheduler, records = solver_utils.configMultiOptimizer(
        args, [model_s2, model_s3])
    # NOTE(review): -1 looks like a "no optimizer" placeholder for the second
    # slot — confirm against train_utils.train.
    optimizers = [optimizer, -1]
    criterion = solver_utils.Stage4Crit(args)
    recorder = recorders.Records(args.log_dir, records)

    train_loader, val_loader = custom_data_loader.pokemonDataloader(args)

    # get_last_lr() (PyTorch >= 1.4) reports the lr set by the last step().
    # On those versions get_lr() called outside step() is deprecated and
    # re-applies the decay factor at decay epochs (e.g. StepLR/MultiStepLR),
    # which would log a wrong value. Older releases only provide get_lr(),
    # so fall back to it there.
    current_lr = getattr(scheduler, 'get_last_lr', scheduler.get_lr)

    for epoch in range(args.start_epoch, args.epochs + 1):
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the
        # epoch's optimizer steps; stepping first shifts the schedule by one
        # epoch. Kept as-is because resume behaviour (start_epoch vs. the
        # scheduler's last_epoch) may rely on it — confirm before moving.
        scheduler.step()

        recorder.insertRecord('train', 'lr', epoch, current_lr()[0])

        train_utils.train(args, train_loader, models, criterion, optimizers,
                          log, epoch, recorder)
        if epoch % args.save_intv == 0:
            model_utils.saveMultiCheckpoint(args.cp_dir, epoch, models,
                                            optimizer, recorder.records, args)
        # Train-curve plotting is disabled for this pipeline:
        # log.plotCurves(recorder, 'train')

        if epoch % args.val_intv == 0:
            test_utils.test(args, 'val', val_loader, models, log, epoch,
                            recorder)
Esempio n. 2
0
def main(args):
    """Train the stage-2 network with the stage-1 network in the pipeline.

    Builds the networks returned by ``custom_model.buildModel`` and
    ``custom_model.buildModelStage2``, configures an optimizer/scheduler for
    the stage-2 network only, and runs the epoch loop from
    ``args.start_epoch`` to ``args.epochs`` inclusive. Train curves are
    plotted every epoch. Every ``args.save_intv`` epochs the stage-2 network
    is checkpointed, and every ``args.val_intv`` epochs a validation pass is
    run and its curves are plotted.

    Args:
        args: option namespace forwarded to every helper; read directly here
            for ``log_dir``, ``start_epoch``, ``epochs``, ``save_intv``,
            ``val_intv`` and ``cp_dir``.

    NOTE(review): ``log`` is not defined in this function and is presumably a
    module-level logger created elsewhere in the file — confirm.
    """
    model = custom_model.buildModel(args)
    model_s2 = custom_model.buildModelStage2(args)
    # Order matters: train/test helpers receive this list as-is.
    models = [model, model_s2]

    # Only the stage-2 network is handed to the optimizer.
    optimizer, scheduler, records = solver_utils.configOptimizer(
        args, model_s2)
    # NOTE(review): -1 looks like a "no optimizer" placeholder for the second
    # slot — confirm against train_utils.train.
    optimizers = [optimizer, -1]
    criterion = solver_utils.Stage2Crit(args)
    recorder = recorders.Records(args.log_dir, records)

    train_loader, val_loader = custom_data_loader.customDataloader(args)

    # get_last_lr() (PyTorch >= 1.4) reports the lr set by the last step().
    # On those versions get_lr() called outside step() is deprecated and
    # re-applies the decay factor at decay epochs (e.g. StepLR/MultiStepLR),
    # which would log a wrong value. Older releases only provide get_lr(),
    # so fall back to it there.
    current_lr = getattr(scheduler, 'get_last_lr', scheduler.get_lr)

    for epoch in range(args.start_epoch, args.epochs + 1):
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the
        # epoch's optimizer steps; stepping first shifts the schedule by one
        # epoch. Kept as-is because resume behaviour (start_epoch vs. the
        # scheduler's last_epoch) may rely on it — confirm before moving.
        scheduler.step()

        recorder.insertRecord('train', 'lr', epoch, current_lr()[0])

        train_utils.train(args, train_loader, models, criterion, optimizers,
                          log, epoch, recorder)
        if epoch % args.save_intv == 0:
            # Only the trainable stage-2 network is checkpointed.
            model_utils.saveCheckpoint(args.cp_dir, epoch, model_s2, optimizer,
                                       recorder.records, args)
        log.plotCurves(recorder, 'train')

        if epoch % args.val_intv == 0:
            test_utils.test(args, 'val', val_loader, models, log, epoch,
                            recorder)
            log.plotCurves(recorder, 'val')
Esempio n. 3
0
def main(args):
    """Run one 'test' pass of the two-stage pipeline on the benchmark loader.

    Creates the loader from ``custom_data_loader.benchmarkLoader``, builds the
    stage-1 and stage-2 networks, and hands them to ``test_utils.test`` with
    a fresh ``recorders.Records`` writing under ``args.log_dir``. The pass is
    reported as epoch 1.

    NOTE(review): ``log`` is resolved as a global defined elsewhere in the
    file — confirm.
    """
    benchmark_loader = custom_data_loader.benchmarkLoader(args)
    stage_models = [
        custom_model.buildModel(args),
        custom_model.buildModelStage2(args),
    ]

    rec = recorders.Records(args.log_dir)
    test_utils.test(args, 'test', benchmark_loader, stage_models, log, 1, rec)