Exemplo n.º 1
0


# Either train (one train + validation + checkpoint cycle per epoch over the
# inclusive range [opt.begin_epoch, opt.end_epoch]) or, when opt.evaluate is
# set, run a single evaluation pass.
if not opt.evaluate:
    for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
        # One training pass, then a validation pass whose f-score drives
        # best-model tracking.
        train(epoch)
        fscore = test(epoch)

        # Strict comparison: a tie with the current best is not a new best.
        is_best = fscore > best_fscore
        if is_best:
            # Both messages are emitted before best_fscore is overwritten, so
            # the second one still reports the previous value.
            logging(f"New best fscore is achieved: {fscore}" )
            logging(f"Previous fscore was: {best_fscore}")
            best_fscore = fscore

        # Everything needed to resume later; is_best is passed along so
        # save_checkpoint can (presumably) keep a separate best copy.
        state = dict(
            wandb_id=wandb_id,
            epoch=epoch,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
            fscore=fscore,
        )
        save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration, epoch)
        logging('Weights are saved to backup directory: %s' % (backupdir))
else:
    logging('evaluating ...')
    test(0)
Exemplo n.º 2
0
                      momentum=momentum,
                      dampening=0,
                      weight_decay=decay * batch_size)

# Extra keyword options (presumably forwarded to a torch DataLoader — verify
# at the call site): worker count and pinned memory are only requested on
# CUDA runs; CPU-only runs pass no extra options.
if use_cuda:
    kwargs = {'num_workers': num_workers, 'pin_memory': True}
else:
    kwargs = {}

# Load resume path if necessary: restore the starting epoch, the stored
# f-score, and the model weights (all matching keys except one excluded entry).
if opt.resume_path:
    print(
        "===================================================================")
    print('loading checkpoint {}'.format(opt.resume_path))
    # Without map_location, torch.load puts tensors back on the device they
    # were saved from, so a GPU-saved checkpoint fails to load on a CPU-only
    # run. Map to CPU in that case. load_state_dict below copies the values
    # into the model's own parameters, so GPU runs are unaffected.
    checkpoint = torch.load(opt.resume_path,
                            map_location=None if use_cuda else 'cpu')
    # NOTE(review): the stored 'epoch' is used as the first epoch to run. If
    # it records the last *completed* epoch, that epoch is trained again on
    # resume — confirm this is intended.
    opt.begin_epoch = checkpoint['epoch']
    # The f-score saved with this checkpoint becomes the current best_fscore.
    best_fscore = checkpoint['fscore']
    pretrained_dict = checkpoint['state_dict']
    model_dict = model.state_dict()
    # Keep only keys the current model also has, and deliberately skip
    # 'module.cfam.conv_bn_relu1.0.weight'. Presumably its shape differs from
    # the current model's — TODO confirm. That entry keeps the model's
    # current value.
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and k != 'module.cfam.conv_bn_relu1.0.weight'
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    # NOTE: the checkpoint's 'optimizer' state is not restored, so the
    # optimizer keeps its freshly constructed state. Calling
    # optimizer.load_state_dict(checkpoint['optimizer']) would restore it.
    # Rebuild the seen-sample counter as (stored epoch) * nsamples.
    model.seen = checkpoint['epoch'] * nsamples
    print("Loaded model fscore: ", checkpoint['fscore'])
    print(
        "===================================================================")

# Keep the loss's counter in sync with the model's. When no checkpoint was
# loaded, model.seen must already have been set elsewhere.
region_loss.seen = model.seen