def main():
    """Evaluate the checkpoint at ``CHECKPOINT`` or read its stored samples.

    Prints an error and returns early when the checkpoint file is missing.
    Otherwise fetches the run arguments, seeds the torch / random / numpy
    generators, overrides a few arguments with module-level settings, and
    loads the checkpoint.

    * When ``PRECOMPUTED`` is false, a ``Model`` is rebuilt from the
      checkpoint's ``model_kwargs`` / ``model_state`` and passed to
      ``eval_model`` together with the test loader.
    * When ``PRECOMPUTED`` is true, the last entry of the checkpoint's
      ``val_samples`` is read instead.

    Returns:
        None.
    """
    if not os.path.isfile(CHECKPOINT):
        print('ERROR: Checkpoint file "%s" not found' % CHECKPOINT)
        return

    # Fetch the run arguments and echo them.
    args = get_args()
    print(args)

    # Fixed seeds so repeated runs draw the same random streams.
    torch.manual_seed(1)
    random.seed(1)
    np.random.seed(1)

    # Override some arguments with the module-level evaluation settings.
    args.add_jitter_bbox = None
    args.add_jitter_layout = None
    args.add_jitter_feats = None
    args.batch_size = BATCH_SIZE
    args.test_h5 = SPLIT

    # Use the first GPU when CUDA is available and fall back to the CPU
    # otherwise; a hard-coded "cuda:0" made the CPU path below unreachable.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # On a CPU-only machine the stored tensors must be remapped to the CPU;
    # on a GPU machine keep the default placement (map_location=None).
    map_location = 'cpu' if device.type == 'cpu' else None
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    if not PRECOMPUTED:
        # Rebuild the model from the constructor arguments and weights
        # stored in the checkpoint, then switch it to eval mode.
        kwargs = checkpoint['model_kwargs']

        model = Model(**kwargs)
        model.load_state_dict(checkpoint['model_state'])
        model.eval()
        model.to(device)

        # Only the test loader is used here; the other returned values
        # are discarded.
        _, _, _, test_loader = build_loaders(args, evaluating=True)

        print('Batch size: ', BATCH_SIZE)
        print('Evaluating on {} set'.format(SPLIT))
        eval_model(args,
                   model,
                   test_loader,
                   device,
                   use_gt=USE_GT,
                   use_feats=USE_FEATS,
                   filter_box=IGNORE_SMALL)
    else:
        # Samples were already computed while training (only one batch);
        # take the ones from the last recorded iteration.
        samples = checkpoint['val_samples'][-1]
        # NOTE(review): these arrays are not used further in this function;
        # confirm whether downstream code (e.g. visualisation) is missing.
        original_img = samples['gt_img'].cpu().numpy()
        predicted_img = samples['gt_box_pred_mask'].cpu().numpy()
예제 #2
0
파일: train.py 프로젝트: LUGUANSONG/i2g2i
                                            result.ac_loss_fake)
            trainer.train_image_discriminator(result.loss_d_fake_img,
                                              result.loss_d_wrong_texture,
                                              result.loss_D_real)

            if t % args.print_every == 0 or t == 1:
                trainer.write_losses(checkpoint, t)
                trainer.write_images(t, imgs, imgs_pred, layout_pred_one_hot, layout_pred_one_hot, \
                                     d_real_crops, d_fake_crops)

            if t % args.checkpoint_every == 0:
                print('begin check model train')
                train_results = check_model(args,
                                            val_loader,
                                            trainer,
                                            inception_score,
                                            use_gt=True)
                print('begin check model val')
                val_results = check_model(args,
                                          val_loader,
                                          trainer,
                                          inception_score,
                                          use_gt=False)
                trainer.save_checkpoint(checkpoint, t, args, epoch,
                                        train_results, val_results)


# Script entry point: the body runs only when this file is executed directly,
# not when it is imported as a module.
# NOTE(review): `main` is called with `args` here, while the `def main()`
# visible earlier in this listing takes no parameters -- the two appear to
# come from different files; confirm the signature of the intended `main`.
if __name__ == '__main__':
    args = get_args()
    main(args)