Beispiel #1
0
def __main__():
    """Entry point: parse CLI args, build a rating model, optionally train,
    then evaluate on the test split.

    Side effects: loads the dataset, may load/save model checkpoints,
    prints metrics to stdout.
    """
    args = parse_args()
    # Use the GPU only when both requested and actually available.
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    train, validate, test, n_users, n_items, scale = loadDataset(args.dataset)
    print("n_users:", n_users, "n_items:", n_items)

    if args.model == 'nn':
        model = TestModel(n_users, n_items, 64, dropout=0.15,
                          scale=scale).to(device)
    elif args.model == 'dot':
        model = DotModel(n_users, n_items, 64, dropout=0.15,
                         scale=scale).to(device)
    else:
        # Fail fast with a clear message; previously an unknown model type
        # fell through and raised NameError on `model` below.
        raise ValueError("unknown model type: {!r}".format(args.model))
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters())
    load_model(model, args)

    if args.mode == 'train':
        # Track the best validation loss; checkpoint only on improvement.
        best_loss = test_one_epoch(model, validate, loss_fn, device)
        for epoch in range(args.epochs):
            train_one_epoch(model, train, loss_fn, optimizer, device)
            loss = test_one_epoch(model, validate, loss_fn, device)
            if loss < best_loss:
                best_loss = loss
                save_model(model, args)

    # Final evaluation on the held-out test split.
    loss = test_one_epoch(model, test, loss_fn, device, mode='test')
def finetune_model(model, data_loader):
    """Run a single epoch of finetuning on *model* using *data_loader*.

    Uses cross-entropy loss and SGD with a very small learning rate (1e-4).
    NOTE(review): `lr_scheduler` is constructed but never stepped or passed
    to `train_one_epoch`, so the decay schedule has no visible effect here —
    confirm whether that is intended.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    # Would decay the LR by 10x every epoch, but is unused (see NOTE above).
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=1,
                                                   gamma=0.1)

    # Training a single epoch on device "cuda" (start epoch 0, print freq 100).
    train_one_epoch(model, criterion, optimizer, data_loader, "cuda", 0, 100)
Beispiel #3
0
def main(_):
    """Train a GAN under TensorFlow eager execution.

    Reads hyperparameters from FLAGS, builds generator/discriminator plus
    their Adam optimizers, restores the latest checkpoint if one exists,
    then trains for FLAGS.epoch epochs, saving a checkpoint after each.
    """

    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    # NOTE(review): `data_dir` is not defined in this function — presumably a
    # module-level constant or flag; confirm it is in scope before running.
    filenames = glob.glob(data_dir)

    # Prefer GPU with channels_first; fall back to CPU with channels_last.
    (device, data_format) = ('/gpu:0', 'channels_first')
    if FLAGS.no_gpu or tfe.num_gpus() <= 0:
        (device, data_format) = ('/cpu:0', 'channels_last')
    print('Using device %s, and data format %s.' % (device, data_format))

    # Make sure the output directories exist.
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Everything that must survive a restart is bundled into the checkpoint
    # and also passed as kwargs to train_one_epoch below.
    model_objects = {
        'generator': Generator(data_format),
        'discriminator': Discriminator(data_format),
        'generator_optimizer': tf.train.AdamOptimizer(FLAGS.generator_learning_rate, FLAGS.beta1, FLAGS.beta2),
        'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.discriminator_learning_rate, FLAGS.beta1, FLAGS.beta2),
        'step_counter': tf.train.get_or_create_global_step()
    }

    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.summary_dir,
                                                           flush_millis=1000)

    checkpoint = tfe.Checkpoint(**model_objects)
    checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
    latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if latest_cpkt:
        print('Using latest checkpoint at ' + latest_cpkt)
    # restore() is also called when latest_cpkt is None (fresh start).
    checkpoint.restore(latest_cpkt)

    dataset = tf.data.TFRecordDataset(
        filenames).map(read_and_decode_with_labels)
    # Drop the last partial batch so every batch has exactly batch_size items.
    dataset = dataset.shuffle(10000).apply(
        tf.contrib.data.batch_and_drop_remainder(FLAGS.batch_size))

    with tf.device(device):
        for epoch in range(FLAGS.epoch):
            start = time.time()
            with summary_writer.as_default():
                train_one_epoch(dataset=dataset, batch_size=FLAGS.batch_size, log_interval=FLAGS.log_interval,
                                z_dim=FLAGS.z_dim, device=device, epoch=epoch, **model_objects)
            end = time.time()
            checkpoint.save(checkpoint_prefix)
            # The printed "epoch" number is the checkpoint save counter, which
            # counts all saves across runs, not the local `epoch` variable.
            print('\nTrain time for epoch #%d (step %d): %f' %
                  (checkpoint.save_counter.numpy(),
                   checkpoint.step_counter.numpy(),
                   end - start))
Beispiel #4
0
def trainer(feature_network, task_network, embedding_network, train_data_full,
            train_data1, train_data2, train_data3, test_data, epochs,
            learning_rate, eps):
    """Meta-train feature/task/embedding networks across three source domains.

    Each epoch the three source domains are presented in a fresh random
    order, one meta-training step is run, and train/test loss and accuracy
    are reported.
    """
    # Shared loss for every network.
    loss_function = nn.CrossEntropyLoss()

    def _make_optimizer(network):
        # All three networks use identical SGD settings.
        return optim.SGD(network.parameters(), lr=learning_rate, momentum=0.9)

    optimizer_feature = _make_optimizer(feature_network)
    optimizer_task = _make_optimizer(task_network)
    optimizer_embedding = _make_optimizer(embedding_network)

    def _fmt(value):
        # Two-decimal fixed-point rendering of a numpy scalar/array.
        return np.array2string(value, precision=2, floatmode='fixed')

    for epoch in range(epochs):
        # Shuffle the presentation order of the three source domains.
        order = np.array([0, 1, 2])
        np.random.shuffle(order)
        domains = [train_data1, train_data2, train_data3]
        train_input1 = domains[order[0]]
        train_input2 = domains[order[1]]
        train_input3 = domains[order[2]]

        # One meta-training step over the shuffled domains.
        train_one_epoch(feature_network, task_network, embedding_network,
                        train_input1, train_input2, train_input3,
                        train_data_full, optimizer_feature, optimizer_task,
                        optimizer_embedding, eps, learning_rate, loss_function)

        # Measure performance on the full training data and the test set.
        loss_train, accuracy_train, loss_test, accuracy_test = validate_epoch(
            train_data_full, test_data, feature_network, task_network,
            loss_function)

        template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
        print(template.format(epoch,
                              _fmt(loss_train),
                              _fmt(accuracy_train * 100),
                              _fmt(loss_test),
                              _fmt(accuracy_test * 100)))

    print('Finished Training')
Beispiel #5
0
def main():
    """Train a mixture density network on synthetic mixture data.

    After every epoch the logger curves and the predicted density are
    plotted; at the end the learned mixture weights are rendered as a
    heat map in ./log/weight.png.
    """
    parser = ArgumentParser()
    # Hyperparameters, declared data-driven to keep the setup compact.
    for flag, default, kind in (
            ('--num-mixtures', 10, int),
            ('--num-data-mixtures', 5, int),
            ('--batch-size', 128, int),
            ('--learning-rate', 1e-4, float),
            ('--num-epochs', 10, int),
            ('--hidden-dim', 128, int),
            ('--seed', 42, int)):
        parser.add_argument(flag, default=default, type=kind)
    args = parser.parse_args()

    print('Config', args._get_kwargs())

    dataset = MixtureDataset(num_mixtures=args.num_data_mixtures)
    set_seed(args.seed)
    loader = DataLoader(dataset.data,
                        batch_size=args.batch_size,
                        shuffle=True,
                        drop_last=True)

    device = torch.device('cuda') if torch.cuda.is_available() \
        else torch.device('cpu')
    model = MDN(hidden_dim=args.hidden_dim,
                num_mixtures=args.num_mixtures).to(device)
    # One shared latent vector; a batch-sized copy feeds training.
    hx = torch.randn(1, args.hidden_dim).to(device)
    batched_hx = hx.repeat(args.batch_size, 1)

    optimizer = SGD(model.parameters(), lr=args.learning_rate)
    scheduler = OneCycleLR(optimizer=optimizer,
                           max_lr=args.learning_rate,
                           epochs=args.num_epochs,
                           steps_per_epoch=len(loader))

    logger = Logger('./log')

    for epoch in range(args.num_epochs):
        train_one_epoch(epoch, model, loader, optimizer, scheduler,
                        device, batched_hx, logger)
        logger.plot(f'log_{epoch}.png')
        plot_mdn_density(model, hx, dataset.data, device,
                         f'log/density_{epoch}.png')

    # Visualize the learned mixture weights (last chunk of the prediction).
    pred = torch.chunk(model(hx), chunks=6, dim=-1)
    weight = pred[-1]
    plt.clf()
    plt.matshow(weight.softmax(dim=-1).data.cpu().numpy(), cmap='hot')
    plt.colorbar()
    plt.savefig('./log/weight.png')
Beispiel #6
0
def main():
    """Train TextCNN on the Sina dataset, checkpointing after every epoch.

    Loops until args.max_epoch is reached (forever if max_epoch is falsy),
    then writes the accumulated metrics to train-result.json.
    """
    train_set = SinaDataset(path.join(args.source, 'train.json'), input_dim)
    test_set = SinaDataset(path.join(args.source, 'test.json'), input_dim)

    def _make_loader(split):
        # Both splits share the same loader configuration.
        return DataLoader(split,
                          batch_size=args.bs,
                          shuffle=True,
                          drop_last=True)

    train_loader = _make_loader(train_set)
    test_loader = _make_loader(test_set)

    model = TextCNN(input_dim, 200).to(device)
    optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.wd)

    # Accumulated per-batch metrics across all epochs.
    history = {
        'train-loss': [],
        'train-accu': [],
        'valid-loss': [],
        'valid-accu': [],
    }
    epoch = 0
    while True:
        epoch += 1
        epoch_loss, epoch_accu = train_one_epoch(epoch, model, optimizer,
                                                 train_loader, device, args.bs)
        val_loss, val_accu = validate(model, test_loader, device, args.bs)
        history['train-loss'] += epoch_loss
        history['train-accu'] += epoch_accu
        history['valid-loss'] += val_loss
        history['valid-accu'] += val_accu

        print('saving...')
        torch.save(model.state_dict(),
                   './saved_models/epoch' + str(epoch) + '.pkl')
        print()

        if args.max_epoch and epoch >= args.max_epoch:
            # Persist the full training history once the last epoch finishes.
            train_result = {
                'batch-size': args.bs,
                'train-loss': history['train-loss'],
                'train-accu': history['train-accu'],
                'valid-loss': history['valid-loss'],
                'valid-accu': history['valid-accu']
            }
            with open('train-result.json', 'w', encoding='utf-8') as f:
                json.dump(train_result, f)

            break
def main():
    """Build, optionally resume, train, evaluate and checkpoint a model.

    Everything is driven by the module-level `cfg`. Each stage is
    best-effort: failures are logged and the pipeline moves on.
    """
    # Set logger to record information.
    logger = Logger(cfg)
    logger.log_info(cfg)
    metrics_logger = Metrics()
    utils.pack_code(cfg, logger=logger)

    # Build model.
    model = model_builder.build_model(cfg=cfg, logger=logger)

    # Read checkpoint.
    ckpt = torch.load(cfg.MODEL.PATH2CKPT) if cfg.GENERAL.RESUME else {}

    if cfg.GENERAL.RESUME:
        model.load_state_dict(ckpt["model"])
    resume_epoch = ckpt["epoch"] if cfg.GENERAL.RESUME else 0
    optimizer = ckpt[
        "optimizer"] if cfg.GENERAL.RESUME else optimizer_helper.build_optimizer(
            cfg=cfg, model=model)
    # lr_scheduler = ckpt["lr_scheduler"] if cfg.GENERAL.RESUME else lr_scheduler_helper.build_scheduler(cfg=cfg, optimizer=optimizer)
    lr_scheduler = lr_scheduler_helper.build_scheduler(cfg=cfg,
                                                       optimizer=optimizer)
    lr_scheduler.sychronize(resume_epoch)
    loss_fn = ckpt[
        "loss_fn"] if cfg.GENERAL.RESUME else loss_fn_helper.build_loss_fn(
            cfg=cfg)

    # Set device.
    model, device = utils.set_device(model, cfg.GENERAL.GPU)

    # Prepare dataset. Bug fix: the original bare `except:` clauses also
    # swallowed KeyboardInterrupt/SystemExit; `except Exception:` keeps the
    # best-effort behavior without hiding interpreter exits.
    if cfg.GENERAL.TRAIN:
        try:
            train_data_loader = data_loader.build_data_loader(
                cfg, cfg.DATA.DATASET, "train")
        except Exception:
            logger.log_info("Cannot build train dataset.")
    if cfg.GENERAL.VALID:
        try:
            valid_data_loader = data_loader.build_data_loader(
                cfg, cfg.DATA.DATASET, "valid")
        except Exception:
            logger.log_info("Cannot build valid dataset.")
    if cfg.GENERAL.TEST:
        try:
            test_data_loader = data_loader.build_data_loader(
                cfg, cfg.DATA.DATASET, "test")
        except Exception:
            logger.log_info("Cannot build test dataset.")

    # Train, evaluate model and save checkpoint.
    for epoch in range(cfg.TRAIN.MAX_EPOCH):
        # Skip epochs that were already completed before resuming.
        if resume_epoch >= epoch:
            continue

        try:
            train_one_epoch(
                epoch=epoch,
                cfg=cfg,
                model=model,
                data_loader=train_data_loader,
                device=device,
                loss_fn=loss_fn,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
                metrics_logger=metrics_logger,
                logger=logger,
            )
        except Exception:
            logger.log_info("Failed to train model.")

        optimizer.zero_grad()
        # First checkpoint right after training this epoch.
        with torch.no_grad():
            utils.save_ckpt(
                path2file=os.path.join(
                    cfg.MODEL.CKPT_DIR,
                    cfg.GENERAL.ID + "_" + str(epoch).zfill(3) + ".pth"),
                logger=logger,
                model=model.state_dict(),
                epoch=epoch,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,  # NOTE Need attribdict>=0.0.5
                loss_fn=loss_fn,
                metrics=metrics_logger,
            )
        try:
            evaluate(
                epoch=epoch,
                cfg=cfg,
                model=model,
                data_loader=valid_data_loader,
                device=device,
                loss_fn=loss_fn,
                metrics_logger=metrics_logger,
                phase="valid",
                logger=logger,
                save=cfg.SAVE.SAVE,
            )
        except Exception:
            logger.log_info("Failed to evaluate model.")

        # Second checkpoint after validation (overwrites the same file,
        # capturing any state updated during evaluation).
        with torch.no_grad():
            utils.save_ckpt(
                path2file=os.path.join(
                    cfg.MODEL.CKPT_DIR,
                    cfg.GENERAL.ID + "_" + str(epoch).zfill(3) + ".pth"),
                logger=logger,
                model=model.state_dict(),
                epoch=epoch,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,  # NOTE Need attribdict>=0.0.5
                loss_fn=loss_fn,
                metrics=metrics_logger,
            )

    # If test set has target images, evaluate and save them, otherwise just try to generate output images.
    if cfg.DATA.DATASET == "DualPixelNTIRE2021":
        try:
            generate(
                cfg=cfg,
                model=model,
                data_loader=valid_data_loader,
                device=device,
                phase="valid",
                logger=logger,
            )
        except Exception:
            logger.log_info(
                "Failed to generate output images of valid set of NTIRE2021.")
    try:
        evaluate(
            epoch=epoch,
            cfg=cfg,
            model=model,
            data_loader=test_data_loader,
            device=device,
            loss_fn=loss_fn,
            metrics_logger=metrics_logger,
            phase="test",
            logger=logger,
            save=True,
        )
    except Exception:
        logger.log_info("Failed to test model, try to generate images.")
        try:
            generate(
                cfg=cfg,
                model=model,
                data_loader=test_data_loader,
                device=device,
                phase="test",
                logger=logger,
            )
        except Exception:
            logger.log_info("Cannot generate output images of test set.")
    return None
Beispiel #8
0
    static_noise = latent_space(16, device=device)

    metric_logger = MetricLogger('DCGAN',
                                 'MNIST',
                                 losswise_api_key=args.api_key,
                                 tensorboard=args.tensorboard)

    start_time = time.time()
    for epoch in range(cfg.NUM_EPOCHS):
        train_one_epoch(generator,
                        discriminator,
                        dataloader,
                        G_optimizer,
                        D_optimizer,
                        criterion,
                        device,
                        epoch,
                        static_noise,
                        metric_logger,
                        num_sumples=16,
                        freq=100)
        if epoch % args.save_state == 0:
            MetricLogger.checkpoint(epoch, generator, G_optimizer)
            MetricLogger.checkpoint(epoch, discriminator, D_optimizer)
    if args.save_models:
        metric_logger.save_models(generator, discriminator, cfg.NUM_EPOCHS)

    metric_logger.dump_metrics()
    metric_logger.plot_metrics()
    total_time = time.time() - start_time
    print('Training time {}'.format(total_time))
Beispiel #9
0
        print('Modeled loaded from {} with metrics:'.format(args.resume))
        print(results)
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

# Resume the training clock from the requested starting epoch.
clock.epoch = args.start_epoch

# NOTE(review): nothing in this loop visibly advances clock.epoch —
# presumably train_one_epoch/val_one_epoch increment it through `clock`;
# confirm, otherwise this loop never terminates.
while True:
    if clock.epoch > args.epochs:
        break

    # Epoch-dependent learning-rate schedule.
    adjust_learning_rate(optimizer, clock.epoch)

    acc, crossLoss, TVLoss = train_one_epoch(net, optimizer, ds_train,
                                             CrossEntropyCriterion, clock,
                                             args.tv)

    valacc, valcrossLoss, valTVLoss = val_one_epoch(net, ds_val,
                                                    CrossEntropyCriterion,
                                                    clock, args.tv)

    # Per-epoch benchmark results go to their own subdirectory.
    result_dir = os.path.join(benchmark_result_dir, '{}'.format(clock.epoch))

    test_on_benchmark(net, None, CrossEntropyCriterion, result_dir)

    #if clock.epoch > args.epochs *2 / 3 and clock.epoch % 2 ==1:
    #    rb = evalRob(net, ds_val)
    #    print(rb)

#with open(os.path.join(exp_dir, 'res.txt'),'w') as f:
Beispiel #10
0
def controller_3d(mode: str, focus: str):
    """
    Controller function for training and testing VNet
    on Volumetric CT images of livers with and without lesions.

    Args:
        mode: 'test' runs evaluation only and exits; anything else trains.
        focus: forwarded to LiTSDataset (which structure to segment).
    """
    print("Starting!!! :D")

    ## Load data sets
    tr_path = os.path.join(config["dstpath"], "train/")
    te_path = os.path.join(config["dstpath"], "test/")

    tr_set = LiTSDataset(tr_path, focus=focus)
    te_set = LiTSDataset(te_path, focus=focus)

    ## Initialize and load model if specified in config
    net = VNet(training=True,
               drop_rate=config["drop_rate"],
               binary_output=True)
    if config["init_model_state"] is not None:
        # Bug fix: corrected "Attemt" typo in the log message.
        print("Attempt fetching of state dict at: ",
              config["init_model_state"])
        state_dict = torch.load(config["init_model_state"])
        net.load_state_dict(state_dict)
        print("Successfully loaded net")
    net.to(device)

    ## Test performance and store predictions if specified
    if mode == 'test':
        net.eval()

        tr_dataloader = DataLoader(tr_set)
        train_info = test_one_epoch(net,
                                    tr_dataloader,
                                    device,
                                    1,
                                    1,
                                    wandblog=False,
                                    dst_path=None)
        tr_df = pd.DataFrame(train_info)
        tr_df.to_csv(
            os.path.join(
                config["dstpath"],
                "tr_metrics_{}_{}_run{:02}.csv".format(mode, focus,
                                                       config["runid"])))

        te_dataloader = DataLoader(te_set)
        test_info = test_one_epoch(net,
                                   te_dataloader,
                                   device,
                                   1,
                                   1,
                                   wandblog=False)
        te_df = pd.DataFrame(test_info)
        te_df.to_csv(
            os.path.join(
                config["dstpath"],
                "te_metrics_{}_{}_run{:02}.csv".format(mode, focus,
                                                       config["runid"])))
        # NOTE: exit() terminates the whole process here, not just this call.
        exit()

    ## Monitor process with weights and biases
    wandb.init(config=config)

    optimizer = torch.optim.Adam(params=net.parameters(),
                                 **config["optim_opts"])

    critic = TverskyLoss(**config["loss_opts"])

    ## Training loop
    for epoch in range(config["max_epochs"]):
        print(f"Epoch: {epoch}.")

        ## Dataloaders
        print("Loading data ...")
        tr_dataloader = DataLoader(tr_set, shuffle=True)
        te_dataloader = DataLoader(te_set, shuffle=True)

        epochlength = len(tr_dataloader) + len(te_dataloader)

        print("Training ...")
        train_info = train_one_epoch(net, optimizer, critic, tr_dataloader,
                                     device, epoch, epochlength)

        ## Get dice global for training
        train_dice_global = np.sum(
            train_info['train_dice_numerator']) / np.sum(
                train_info['train_dice_denominator'])
        print("Global train dice at epoch {}: {}".format(
            epoch, train_dice_global))
        wandb.log({"train_dice_global": train_dice_global})

        print("Testing ...")
        test_info = test_one_epoch(net, te_dataloader, device,
                                   epoch + len(tr_dataloader) / epochlength,
                                   epochlength)

        ## Get dice global for testing
        test_dice_global = np.sum(test_info['test_dice_numerator']) / np.sum(
            test_info['test_dice_denominator'])
        print("Global test dice at epoch {}: {}".format(
            epoch, test_dice_global))
        wandb.log({"test_dice_global": test_dice_global})

        ## Save model state at checkpoints.
        netname = str(type(net)).strip("'>").split(".")[1]
        state_dict_path = "datasets/saved_states/{}_runid_{:02}_epoch{:02}.pth".format(
            netname, config["runid"], epoch)
        # NOTE(review): `epoch % (interval - 1) == 0` looks off by one (e.g.
        # interval=5 saves every 4th epoch) — confirm the intended cadence.
        if epoch % (config["checkpoint_interval"] - 1) == 0 and epoch != 0:
            torch.save(net.state_dict(), state_dict_path)
Beispiel #11
0
def main():
    """Config-driven pipeline: build, resume, train, checkpoint, evaluate
    and run inference over all datasets declared in `cfg.data.datasets`.

    Returns None; all results flow through the logger, checkpoints and the
    evaluate/inference helpers.
    """
    # Set logger to record information.
    utils.check_env(cfg)
    logger = Logger(cfg)
    logger.log_info(cfg)
    metrics_handler = MetricsHandler(cfg.metrics)
    # utils.pack_code(cfg, logger=logger)

    # Build model.
    model = model_builder.build_model(cfg=cfg, logger=logger)
    optimizer = optimizer_helper.build_optimizer(cfg=cfg, model=model)
    lr_scheduler = lr_scheduler_helper.build_scheduler(cfg=cfg,
                                                       optimizer=optimizer)

    # Read checkpoint.
    ckpt = torch.load(cfg.model.path2ckpt) if cfg.gnrl.resume else {}
    if cfg.gnrl.resume:
        with logger.log_info(msg="Load pre-trained model.",
                             level="INFO",
                             state=True,
                             logger=logger):
            model.load_state_dict(ckpt["model"])
            optimizer.load_state_dict(ckpt["optimizer"])
            lr_scheduler.load_state_dict(ckpt["lr_scheduler"])

    # Set device.
    model, device = utils.set_pipline(
        model, cfg) if cfg.gnrl.PIPLINE else utils.set_device(
            model, cfg.gnrl.cuda)

    resume_epoch = ckpt["epoch"] if cfg.gnrl.resume else 0
    loss_fn = loss_fn_helper.build_loss_fn(cfg=cfg)

    # Prepare dataset. Bug fix: bare `except:` also swallowed
    # KeyboardInterrupt/SystemExit; narrowed to `except Exception:`.
    train_loaders, valid_loaders, test_loaders = dict(), dict(), dict()
    for dataset in cfg.data.datasets:
        if cfg.data[dataset].TRAIN:
            try:
                train_loaders[dataset] = data_loader.build_data_loader(
                    cfg, dataset, "train")
            except Exception:
                utils.notify(msg="Failed to build train loader of %s" %
                             dataset)
        if cfg.data[dataset].VALID:
            try:
                valid_loaders[dataset] = data_loader.build_data_loader(
                    cfg, dataset, "valid")
            except Exception:
                utils.notify(msg="Failed to build valid loader of %s" %
                             dataset)
        if cfg.data[dataset].TEST:
            try:
                test_loaders[dataset] = data_loader.build_data_loader(
                    cfg, dataset, "test")
            except Exception:
                utils.notify(msg="Failed to build test loader of %s" % dataset)

    # TODO Train, evaluate model and save checkpoint.
    for epoch in range(cfg.train.max_epoch):
        # Epochs are 1-based; skip those already done before resuming.
        epoch += 1
        if resume_epoch >= epoch:
            continue

        eval_kwargs = {
            "epoch": epoch,
            "cfg": cfg,
            "model": model,
            "loss_fn": loss_fn,
            "device": device,
            "metrics_handler": metrics_handler,
            "logger": logger,
            "save": cfg.save.save,
        }
        train_kwargs = {
            "epoch": epoch,
            "cfg": cfg,
            "model": model,
            "loss_fn": loss_fn,
            "optimizer": optimizer,
            "device": device,
            "lr_scheduler": lr_scheduler,
            "metrics_handler": metrics_handler,
            "logger": logger,
        }
        ckpt_kwargs = {
            "epoch": epoch,
            "cfg": cfg,
            "model": model.state_dict(),
            "metrics_handler": metrics_handler,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }

        for dataset in cfg.data.datasets:
            if cfg.data[dataset].TRAIN:
                utils.notify("Train on %s" % dataset)
                train_one_epoch(data_loader=train_loaders[dataset],
                                **train_kwargs)

        # Always refresh the rolling checkpoint after each epoch.
        utils.save_ckpt(path2file=cfg.model.path2ckpt, **ckpt_kwargs)

        # At designated epochs, keep a permanent snapshot and run the test
        # sets in addition to the per-epoch validation below.
        if epoch in cfg.gnrl.ckphs:
            utils.save_ckpt(path2file=os.path.join(
                cfg.model.ckpts,
                cfg.gnrl.id + "_" + str(epoch).zfill(5) + ".pth"),
                            **ckpt_kwargs)
            for dataset in cfg.data.datasets:
                utils.notify("Evaluating test set of %s" % dataset,
                             logger=logger)
                if cfg.data[dataset].TEST:
                    evaluate(data_loader=test_loaders[dataset],
                             phase="test",
                             **eval_kwargs)

        for dataset in cfg.data.datasets:
            utils.notify("Evaluating valid set of %s" % dataset, logger=logger)
            if cfg.data[dataset].VALID:
                evaluate(data_loader=valid_loaders[dataset],
                         phase="valid",
                         **eval_kwargs)
    # End of train-valid for loop.

    # NOTE(review): `epoch` below is the loop variable left over from the
    # last iteration; if max_epoch is 0 this raises NameError — confirm.
    eval_kwargs = {
        "epoch": epoch,
        "cfg": cfg,
        "model": model,
        "loss_fn": loss_fn,
        "device": device,
        "metrics_handler": metrics_handler,
        "logger": logger,
        "save": cfg.save.save,
    }

    for dataset in cfg.data.datasets:
        if cfg.data[dataset].VALID:
            utils.notify("Evaluating valid set of %s" % dataset, logger=logger)
            evaluate(data_loader=valid_loaders[dataset],
                     phase="valid",
                     **eval_kwargs)
    for dataset in cfg.data.datasets:
        if cfg.data[dataset].TEST:
            utils.notify("Evaluating test set of %s" % dataset, logger=logger)
            evaluate(data_loader=test_loaders[dataset],
                     phase="test",
                     **eval_kwargs)

    # Optional inference passes, per split, as requested in the config.
    for dataset in cfg.data.datasets:
        if "train" in cfg.data[dataset].INFER:
            utils.notify("Inference on train set of %s" % dataset)
            inference(data_loader=train_loaders[dataset],
                      phase="infer_train",
                      **eval_kwargs)
        if "valid" in cfg.data[dataset].INFER:
            utils.notify("Inference on valid set of %s" % dataset)
            inference(data_loader=valid_loaders[dataset],
                      phase="infer_valid",
                      **eval_kwargs)
        if "test" in cfg.data[dataset].INFER:
            utils.notify("Inference on test set of %s" % dataset)
            inference(data_loader=test_loaders[dataset],
                      phase="infer_test",
                      **eval_kwargs)

    return None
Beispiel #12
0
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.3)
# Cosine annealing over the full run, entered after a warm-up phase.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.num_epochs*len(data_loader))
lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=100,
                                      total_epoch=min(1000, len(data_loader)-1),
                                      after_scheduler=scheduler_cosine)

# loss function (alternatives kept for experimentation)
criterion = DiceLoss()
# criterion = Weight_Soft_Dice_Loss(weight=[0.1, 0.9])
# criterion = BCELoss()
# criterion = MixedLoss(10.0, 2.0)
# criterion = Weight_BCELoss(weight_pos=0.25, weight_neg=0.75)
# criterion = Lovasz_Loss(margin=[1, 5]

print('start training...')
train_start = time.time()
for epoch in range(config.num_epochs):
    epoch_start = time.time()
    model_ft, optimizer = train_one_epoch(model_ft, data_loader, criterion,
                                          optimizer, lr_scheduler=lr_scheduler, device=device,
                                          epoch=epoch, vis=vis)
    do_valid(model_ft, dataloader_val, criterion, epoch, device, vis=vis)
    # Bug fix: the elapsed seconds were divided by 60/60 (i.e. hours) while
    # being labeled minutes; divide by 60 so the value matches the unit.
    print('Epoch time: {:.3f}min\n'.format((time.time()-epoch_start)/60))

print('total train time: {}hours {}min'.format(int((time.time()-train_start)/60//60), int((time.time()-train_start)/60%60)))
inference_all(model_ft, device=device)
inference(model_ft, device=device)
# Save the trained model and the optimizer state separately.
torch.save(model_ft, f'{config.model}.pth')
torch.save({'optimizer': optimizer.state_dict(),
            'epoch': epoch,},
            'optimizer.pth')
def main(args):
    """Quantization-aware training (QAT) / post-training quantization driver.

    Modes, selected by flags on ``args``:
      * ``post_training_quantize`` — calibrate a fused fp32 model on a subset
        of the training data, convert it to a quantized model, evaluate, save,
        and return.
      * ``test_only`` — evaluate a (pre-)quantized model and return.
      * default — prepare the model for QAT and train for ``args.epochs``
        epochs; each epoch a CPU copy is converted to a true quantized model,
        evaluated, and checkpointed alongside the QAT model.
    """

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    # Post-training quantization calibrates on a single process only.
    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed "
                           "on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
                                                                   args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        # QAT setup: fuse conv/bn/relu, attach the fake-quant qconfig, and
        # insert observers/fake-quant modules in place.
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
        torch.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay)

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=args.lr_step_size,
                                                       gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    # Keep a handle on the bare model so checkpoints don't embed the DDP wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(
            dataset,
            indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
            pin_memory=True)
        model.eval()
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
        torch.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        # Running evaluation feeds data through the observers to collect
        # activation statistics; the accuracy printed here is incidental.
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.quantization.convert(model, inplace=True)
        if args.output_dir:
            print('Saving quantized model')
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir,
                           'quantized_post_train_model.pth'))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    # QAT training loop: observers and fake-quant start enabled and are frozen
    # after the configured number of epochs.
    model.apply(torch.quantization.enable_observer)
    model.apply(torch.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print('Starting training for epoch', epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        with torch.no_grad():
            # Freeze quantization parameters / BN statistics once they have
            # stabilized, per the epoch thresholds in args.
            if epoch >= args.num_observer_update_epochs:
                print('Disabling observer for subseq epochs, epoch = ', epoch)
                model.apply(torch.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print('Freezing BN for subseq epochs, epoch = ', epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print('Evaluate QAT model')

            evaluate(model, criterion, data_loader_test, device=device)
            # Convert a CPU deep copy so training can continue on the
            # fake-quantized model while we measure true int8 accuracy.
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device('cpu'))
            torch.quantization.convert(quantized_eval_model, inplace=True)

            print('Evaluate Quantized model')
            evaluate(quantized_eval_model, criterion, data_loader_test,
                     device=torch.device('cpu'))

        model.train()

        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'eval_model': quantized_eval_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args}
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))
        print('Saving models after epoch ', epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def train_fold(fold_idx, work_dir, train_filenames, test_filenames,
               batch_sampler, epoch, epochs_to_train):
    """Train and validate one cross-validation fold of the SIIM segmentation model.

    Builds a ResNet-UNet wrapped in DataParallel, a RAdam optimizer wrapped in
    SWA, and a cyclical (cosine) LR schedule; resumes model/optimizer state
    from a checkpoint in ``work_dir`` if one exists; runs the training loop
    (which currently breaks after a single epoch — see NOTE below), validates,
    logs metrics to TensorBoard and the fold logger, and saves a checkpoint.
    Every ``epochs_per_cycle`` epochs the SWA running average is updated,
    swapped in, BN statistics refreshed, and the SWA model validated and saved.

    Returns a dict with ``mask_score``, ``class_score``, ``global_step`` and
    the pass-through ``batch_sampler``.
    """
    # NOTE(review): fold_idx is never used in this function body.
    os.makedirs(work_dir, exist_ok=True)
    fold_logger = kfold.FoldLogger(work_dir)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    batch_size = 4

    # model = models.UNet(6, 1)
    # model = models.MyResNetModel()
    model = models.ResNetUNet(n_classes=1, upsample=True)
    # model = models.ResNetUNetPlusPlus(n_classes=1)
    # model = models.EfficientUNet(n_classes=1)

    # model = models.HRNetWithClassifier()

    model.to(device)
    model = torch.nn.DataParallel(model)
    # model.to(device)

    # Scale the effective batch size by the number of GPUs DataParallel spans.
    data_patallel_multiplier = max(1, torch.cuda.device_count())
    # data_patallel_multiplier = 1
    print('data_parallel_multiplier =', data_patallel_multiplier)

    img_size = 1024

    train_dataset = datareader.SIIMDataset('data/dicom-images-train',
                                           'data/train-rle.csv',
                                           ([img_size], [img_size]),
                                           augment=True,
                                           filenames_whitelist=train_filenames)
    # if batch_sampler is None:
    #     batch_sampler = samplers.OnlineHardBatchSampler(train_dataset, batch_size * data_patallel_multiplier,
    #                                                    drop_last=False)
    # train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=os.cpu_count(),
    #                                                batch_sampler=batch_sampler)
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size *
                                                   data_patallel_multiplier,
                                                   shuffle=True,
                                                   num_workers=os.cpu_count())

    val_dataset = datareader.SIIMDataset('data/dicom-images-train',
                                         'data/train-rle.csv',
                                         ([img_size], [img_size]),
                                         filenames_whitelist=test_filenames)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=batch_size *
                                                 data_patallel_multiplier,
                                                 shuffle=False,
                                                 num_workers=os.cpu_count())

    trainable_params = [
        param for param in model.parameters() if param.requires_grad
    ]

    # Scale learning rates linearly with the effective batch size.
    lr_scaling_coefficient = (1 /
                              16) * data_patallel_multiplier * batch_size / 10
    # max_lr = 2e-3 * lr_scaling_coefficient
    # base_lr = 5e-5 * lr_scaling_coefficient

    # OHEM Limited loss works with that divided by 10
    max_lr = 2.5e-4 * lr_scaling_coefficient
    base_lr = 3.5e-5 * lr_scaling_coefficient

    # optim = torch.optim.Adam(params=trainable_params, lr=base_lr, betas=(0.0, 0.9))
    # optim = torch.optim.Adam(params=[
    #     {"params": backbone_parameters, "lr": base_lr},
    #     {"params": head_and_classifier_params, "lr": max_lr}], lr=base_lr)
    # optim = torch.optim.Adam(params=trainable_params, lr=max_lr)
    # optim = torch.optim.AdamW(params=trainable_params, lr=base_lr, weight_decay=0.00001)
    optim = radam.RAdam(params=trainable_params,
                        lr=base_lr,
                        weight_decay=0.0001)
    # Wrap with SWA so a running weight average can be maintained per cycle.
    optim = torchcontrib.optim.SWA(optim)
    # optim = torch.optim.SGD(params=trainable_params,
    #                         momentum=0.98,
    #                         nesterov=True,
    #                         lr=base_lr)
    # optim = torch.optim.SGD(params=trainable_params,
    #                         momentum=0.9,
    #                         nesterov=True,
    #                         lr=base_lr)

    best_metric = 0.0
    _, loaded_best_metric = utils.try_load_checkpoint(work_dir,
                                                      model,
                                                      device,
                                                      optimizer=optim,
                                                      load_optimizer=True)
    if loaded_best_metric is not None: best_metric = loaded_best_metric

    # Experiments show that it often is good to set stepsize equal to 2 − 10 times the number of iterations in an epoch.
    # For example, setting stepsize = 8 ∗ epoch with the CIFAR-10 training run(as shown in Figure 1) only gives slightly
    # better results than setting stepsize = 2 ∗ epoch. (https://arxiv.org/pdf/1506.01186.pdf)
    # cycle_len = 4 == stepsize = 2
    # in my implementation
    epochs_per_cycle = 20
    lr_scheduler = lr_utils.CyclicalLR(max_lr=max_lr,
                                       base_lr=base_lr,
                                       steps_per_epoch=len(train_dataloader),
                                       epochs_per_cycle=epochs_per_cycle,
                                       mode='cosine')
    # Fast-forward the scheduler to the step matching the resumed epoch.
    lr_scheduler.step_value = epoch * len(train_dataloader)

    steps_per_epoch = len(train_dataloader)
    # torch.optim.lr_scheduler.CyclicLR(optimizer=optim, base_lr=base_lr, max_lr=max_lr, step_size_up=steps_per_epoch * 1,
    #                                   step_size_down=steps_per_epoch * 4, mode='triangular', gamma=1.0, scale_fn=None,
    #                                   scale_mode='cycle',
    #                                   cycle_momentum=False, base_momentum=0.8, max_momentum=0.9,
    #                                   last_epoch=-1)

    # model, optimizer = amp.initialize(model, optim, opt_level='O0')

    writer = SummaryWriter(work_dir)
    # NOTE(review): the unconditional `break` at the end makes this loop run
    # exactly once, so epochs_to_train > 1 currently has no effect — confirm
    # whether the break is intentional (one-epoch-per-call driver) or stale.
    for i in range(epochs_to_train):
        train_result_dict = train_one_epoch(model=model,
                                            optimizer=optim,
                                            data_loader=train_dataloader,
                                            device=device,
                                            epoch=epoch,
                                            lr_scheduler=lr_scheduler,
                                            summary_writer=writer,
                                            print_freq=100)

        val_result_dict = validate.validate(model, val_dataloader, device)
        mask_thresh, mask_score = val_result_dict['best_mask_score']
        class_thresh, class_score = val_result_dict['best_class_score']
        global_step = epoch * len(train_dataloader)
        writer.add_scalar('dice', mask_score, global_step=global_step)
        writer.add_scalar('classification_accuracy',
                          class_score,
                          global_step=global_step)
        writer.add_scalar('mean_epoch_loss',
                          train_result_dict['loss'],
                          global_step=global_step)
        writer.add_scalar('epoch', epoch, global_step=global_step)

        # {'best_mask_score': best_mask_score, 'mean_mask_scores': mean_mask_scores,
        #  'best_class_score': best_class_score, 'mean_class_scores': mean_class_scores}
        log_data = {
            'score': val_result_dict['best_mask_score'][1],
            'mask_threshold': val_result_dict['best_mask_score'][0],
            'class_accuracy': val_result_dict['best_class_score'][1],
            'class_thresold': val_result_dict['best_class_score'][0]
        }
        if (epoch + 1) % epochs_per_cycle == 0 and epoch != 0:
            print('Updating SWA running average')
            optim.update_swa()
        epoch += 1
        break

    # if mask_score > best_metric:
    #     best_metric = mask_score
    # if epoch % epochs_per_cycle == 0:
    fold_logger.log_epoch(epoch - 1, log_data)
    utils.save_checkpoint(output_dir=work_dir,
                          epoch=epoch - 1,
                          model=model,
                          optimizer=optim,
                          best_metric=best_metric)

    # At the end of each SWA cycle, swap the averaged weights in, refresh
    # BatchNorm statistics over the training data, and validate/save the
    # SWA model separately.
    if (epoch) % epochs_per_cycle == 0 and epoch != 0:
        optim.swap_swa_sgd()
        print('Swapped SWA buffers')
        print('Updating BatchNorm statistics...')
        optim.bn_update(
            utils.dataloader_image_extract_wrapper(train_dataloader), model,
            device)
        print('Updated BatchNorm statistics')
        print('Validating SWA model...')
        val_result_dict = validate.validate(model, val_dataloader, device)
        log_data = {
            'score': val_result_dict['best_mask_score'][1],
            'mask_threshold': val_result_dict['best_mask_score'][0],
            'class_accuracy': val_result_dict['best_class_score'][1],
            'class_thresold': val_result_dict['best_class_score'][0]
        }
        fold_logger.log_epoch('swa', log_data)
        print('Saved SWA model')
        utils.save_checkpoint(output_dir=work_dir,
                              epoch=None,
                              name='swa',
                              model=model,
                              optimizer=optim,
                              best_metric=best_metric)

    return {
        'mask_score': mask_score,
        'class_score': class_score,
        'global_step': global_step,
        'batch_sampler': batch_sampler
    }
            multimode_model=multimode_model)
    else:
        checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                         meta_model=meta_model,
                                         counts_model=counts_model,
                                         multimode_model=multimode_model)

    best_val_acc = 0.00
    for epoch in range(NUM_EPOCHS):
        start = time.time()
        loss, accuracy = train.train_one_epoch(
            train_df,
            BATCH_SIZE,
            optimizer,
            text_encoder,
            meta_model,
            counts_model,
            multimode_model,
            train_text_encoder=TRAIN_TEXT_ENCODER,
            repeat_first_batch=REPEAT_FIRST_BATCH,
            binary_classification=BINARY_CLASSIFICATION)
        print(f'Epoch: {epoch + 1}, '
              f'train_loss: {loss}, train_accuracy: {accuracy}')
        val_loss, val_accuracy = train.predict_one_epoch(
            val_df,
            BATCH_SIZE,
            text_encoder,
            meta_model,
            counts_model,
            multimode_model,
            binary_classification=BINARY_CLASSIFICATION)
Beispiel #16
0
        G.parameters(),
        lr=learning_rate)  #Not sure what the parameters do, just copying it

    if opts.resume:
        g_checkpoint = torch.load(
            load_path, map_location=torch.device(device))  #Load from
        G.load_state_dict(g_checkpoint['model'])
        optimizer.load_state_dict(g_checkpoint['optimizer'])

    if opts.mode == "train":
        G = G.train()
        current_iter = 0
        for epoch in range(opts.epochs):
            #Put config as argument
            current_iter = train_one_epoch(G, optimizer, dataset, device,
                                           save_dir, current_iter, epoch,
                                           opts.write)
    elif opts.mode == "test":
        g_checkpoint = torch.load(
            load_path, map_location=torch.device(device))  #Load from
        G.load_state_dict(g_checkpoint['model'])
        print("Finished loading")
        G = G.eval()

        wav_folder = pickle.load(open('./data/data.pkl', "rb"))
        uttr_org = wav_folder[2][2]
        uttr_trg = wav_folder[2][3]
        spect_vc = conversion.convert_two(G, uttr_org, uttr_trg)

        with open('./result_pkl/fin_conv.pkl', 'wb+') as handle:
            pickle.dump(spect_vc, handle)
Beispiel #17
0
# Accumulators for per-evaluation clean / adversarial test accuracy.
te_clean_list = []
te_adv_list = []
_te = 0  # number of evaluations performed so far

# Training loop: exit once the configured number of epochs is reached
while True:
    now_epoch = now_epoch + 1
    if now_epoch > args.epochs:
        break

    # Logger output
    descrip_str = 'Training epoch:{}/{} -- lr:{}'.format(now_epoch, args.epochs, lr_scheduler.get_lr()[0])
    logger.info(f'now_epoch: {now_epoch:.1f}, lr: {lr_scheduler.get_lr()[0]:.2f}')

    # Run the training function (train_attack supplies the adversarial
    # perturbations, weighted by args.adv_coef)
    train_one_epoch(net, train_loader, optimizer, criterion, device, descrip_str, train_attack, adv_coef=args.adv_coef,
                    logger=logger)

    # Evaluate every args.val_interval epochs
    if args.val_interval > 0 and now_epoch % args.val_interval == 0:
        te_clean, te_adv = eval_one_epoch(net, val_loader, device, val_attack, logger=logger)
        te_clean_list.append(te_clean)
        te_adv_list.append(te_adv)
        _te += 1

    # Update the learning rate
    lr_scheduler.step()

    # Save a checkpoint after every epoch
    save_checkpoint(now_epoch, net, optimizer, lr_scheduler,
                    file_name=os.path.join(args.model_dir, 'epoch-{}.checkpoint'.format(now_epoch)))
Beispiel #18
0
def controller_2d(config):  ## (cfg)
    """
    Controller function for training and testing VNet
    on Volumetric CT images of livers with and without lesions.

    In 'test' mode, runs the loaded network over the train/test splits,
    writes per-split metric CSVs and predictions, then exits the process.
    Otherwise trains for config.max_epochs with Adam + exponential LR decay,
    logging to Weights & Biases and checkpointing periodically.
    """
    print("Starting...")

    ## Reproducibility
    if config.seed is not None:
        utils.ensure_reproducibility(config.seed)

    ## Load data
    full_dataset = LiTSDataset2d(config.dst_2d_path,
                                 focus=config.focus,
                                 data_limit=config.data_limit)
    workers = config.num_workers

    ## Split data into train and test
    train_proportion = config.train_proportion
    len_train = int(len(full_dataset) * train_proportion)
    len_test = len(full_dataset) - len_train
    tr_set, te_set = torch.utils.data.random_split(full_dataset,
                                                   (len_train, len_test))

    ## Init and load model if specified in config
    net = VNet2dAsDrawn(drop_rate=config.drop_rate)
    # net = DeepVNet2d(drop_rate=config.drop_rate)

    ## Load model if specified in config
    print(config.init_2d_model_state)
    if config.init_2d_model_state is not None:
        print("Attemt fetching of state dict at: ", config.init_2d_model_state)
        state_dict = torch.load(config.init_2d_model_state)["model_state_dict"]
        net.load_state_dict(state_dict)
        print("Successfully loaded net")
    # NOTE(review): `device` is not defined in this function — presumably a
    # module-level global; verify it is set before controller_2d is called.
    net.to(device)

    ## Make summary
    summary(net, (1, 512, 512))

    ## If only testing, run through net, get resulting metrics and store predictions.
    if config.mode == 'test':
        net.eval()

        tr_dataloader = DataLoader(tr_set,
                                   num_workers=workers,
                                   pin_memory=True)
        if not os.path.exists(config[f"dst2d_pred_{config.focus}_path"]):
            os.mkdir(config[f"dst2d_pred_{config.focus}_path"])
        train_info = test_one_epoch(
            net,
            tr_dataloader,
            device,
            1,
            1,
            wandblog=False,
            dst_path=config[f"dst2d_pred_{config.focus}_path"])
        tr_df = pd.DataFrame(train_info, index=[0])
        tr_df.to_csv(
            os.path.join(
                config.dst_2d_path,
                "tr_metrics_{}_{}_run{:02}.csv".format(config.label_type,
                                                       config.focus,
                                                       config.runid)))

        te_dataloader = DataLoader(te_set,
                                   num_workers=workers,
                                   pin_memory=True)
        test_info = test_one_epoch(
            net,
            te_dataloader,
            device,
            1,
            1,
            wandblog=False,
            dst_path=config[f"dst2d_pred_{config.focus}_path"])
        te_df = pd.DataFrame(test_info, index=[0])
        te_df.to_csv(
            os.path.join(
                config.dst_2d_path,
                "te_metrics_{}_{}_run{:02}.csv".format(config.label_type,
                                                       config.focus,
                                                       config.runid)))
        # Terminate the whole process after test mode; nothing below runs.
        exit()

    ## Monitor process with weights and biases
    wandb.init(config=config)

    ## Optimizer
    optimizer = torch.optim.Adam(params=net.parameters(), **config.optim_opts)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       config.lr_decay_rate,
                                                       last_epoch=-1)

    ## Loss
    if config.label_type == 'segmentation':
        critic = DiceLoss(**config["loss_opts"])
        # critic = TverskyLoss(**config["loss_opts"])
    elif config.label_type == 'pixelcount':
        critic = MSEPixelCountLoss(**config.loss_opts)

    for epoch in range(config.max_epochs):
        print(f"Epoch: {epoch}.")

        ## Make dataloaders
        # NOTE(review): the loaders are re-created inside the loop every
        # epoch; they could be hoisted above the loop unless the rebuild is
        # intentional (e.g. to reshuffle workers) — confirm.
        print("Loading data ...")
        tr_dataloader = DataLoader(tr_set,
                                   batch_size=config.batch_size,
                                   shuffle=True,
                                   num_workers=workers,
                                   pin_memory=True)
        te_dataloader = DataLoader(te_set,
                                   batch_size=config.batch_size,
                                   shuffle=True,
                                   num_workers=workers,
                                   pin_memory=True)

        epochlength = len(tr_dataloader) + len(te_dataloader)

        print("Training ...")
        train_info = train_one_epoch(net, optimizer, critic, tr_dataloader,
                                     device, epoch, epochlength)

        print("Testing ...")
        test_info = test_one_epoch(net, te_dataloader, device, epoch,
                                   epochlength)

        scheduler.step()

        ## Checkpoint storage (with prediction exmple)
        netname = str(type(net)).strip("'>").split(".")[1]
        saved_states_folder = os.path.join(
            "datasets/saved_states/runid_{:03}/".format(config.runid))
        # NOTE(review): with interval k this triggers at multiples of (k - 1),
        # and always at epoch 0 — confirm the "- 1" is intended.
        if epoch % (config.checkpoint_interval - 1) == 0:  # and epoch != 0:
            if not os.path.exists(saved_states_folder):
                os.mkdir(saved_states_folder)
            state_name = "{}_runid_{:02}_epoch{:02}.pth".format(
                netname, config.runid, epoch)
            state_dict_path = os.path.join(saved_states_folder, state_name)

            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "loss": np.mean(train_info["loss"]),  ## mean loss of epoch
                    "config": config,
                },
                state_dict_path)

            ## Store example prediction
            # imageidx = torch.randint(0, len(te_set), (1,))
            # ex_loader = DataLoader(torch.utils.data.Subset(te_set, imageidx),
            #                        num_workers=workers, pin_memory=True)
            # test_one_epoch(net, ex_loader, device, epoch, epochlength,
            #                wandblog=False, dst_format='npy', dst_path=config.dst2d_fig_path)
        print()
Beispiel #19
0
def main(args):
    """Quantization-aware training / post-training quantization driver
    (``torch.ao.quantization`` variant, with optional prototype weights API).

    Modes, selected by flags on ``args``:
      * ``post_training_quantize`` — calibrate a fused fp32 model on a subset
        of the training data, convert to a quantized model, evaluate, save,
        and return.
      * ``test_only`` — evaluate a (pre-)quantized model and return.
      * default — prepare the model for QAT and train for ``args.epochs``
        epochs; each epoch a CPU copy is converted to a true quantized model,
        evaluated, and checkpointed alongside the QAT model.
    """
    # The prototype multi-weight API is opt-in and requires a nightly build.
    if args.prototype and prototype is None:
        raise ImportError(
            "The prototype module couldn't be found. Please install the latest torchvision nightly."
        )
    if not args.prototype and args.weights:
        raise ValueError(
            "The weights parameter works only in prototype mode. Please pass the --prototype argument."
        )
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    # Post-training quantization calibrates on a single process only.
    if args.post_training_quantize and args.distributed:
        raise RuntimeError(
            "Post training quantization example should not be performed on distributed mode"
        )

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " +
                           str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.eval_batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True)

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    if not args.prototype:
        model = torchvision.models.quantization.__dict__[args.model](
            pretrained=True, quantize=args.test_only)
    else:
        model = prototype.models.quantization.__dict__[args.model](
            weights=args.weights, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        # QAT setup: fuse conv/bn/relu, attach the fake-quant qconfig, and
        # insert observers/fake-quant modules in place.
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(
            args.backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    # Keep a handle on the bare model so checkpoints don't embed the DDP wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(dataset,
                                     indices=list(
                                         range(args.batch_size *
                                               args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
        model.eval()
        model.fuse_model(is_qat=False)
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        # Running evaluation feeds data through the observers to collect
        # activation statistics; the accuracy printed here is incidental.
        evaluate(model,
                 criterion,
                 data_loader_calibration,
                 device=device,
                 print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(
                    model.state_dict(),
                    os.path.join(args.output_dir,
                                 "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    # QAT training loop: observers and fake-quant start enabled and are frozen
    # after the configured number of epochs.
    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device,
                        epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            # Freeze quantization parameters / BN statistics once they have
            # stabilized, per the epoch thresholds in args.
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")

            evaluate(model,
                     criterion,
                     data_loader_test,
                     device=device,
                     log_suffix="QAT")
            # Convert a CPU deep copy so training can continue on the
            # fake-quantized model while we measure true int8 accuracy.
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)

            print("Evaluate Quantized model")
            evaluate(quantized_eval_model,
                     criterion,
                     data_loader_test,
                     device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir,
                                         f"model_{epoch}.pth"))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
        print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
                'resnet.layer4.1.conv2', 'resnet.layer4.2.conv1', 'resnet.layer4.2.conv2',\
                'center.0.0', 'center.1.0',  'decoder4.squeeze.0', 'decoder3.squeeze.0', 'decoder2.squeeze.0','decoder1.squeeze.0', \
            ]
        }]
                     
    # Prune model and test accuracy without fine tuning.
    # print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    optimizer_finetune = optimizer
    pruner = L1FilterPruner(model, configure_list, optimizer_finetune)
    model = pruner.compress()

    # Code for fots training
    train_folder_syn = args.train_folder_syn
    train_folder_sample = args.train_folder_sample
    output_path = args.save_dir
    data_set = datasets.MergeText(train_folder_syn, train_folder_sample, datasets.transform, train=True)
    dl = torch.utils.data.DataLoader(data_set, batch_size=args.batch_size, shuffle=True,
                                         sampler=None, batch_sampler=None, num_workers=args.num_workers)
    dl_val = None
    if args.val:
        data_set_val = datasets.MergeText(train_folder_syn, train_folder_sample, datasets.transform, train=False)
        dl_val = torch.utils.data.DataLoader(data_set_val, batch_size=1, shuffle=True,
                                                 sampler=None, batch_sampler=None, num_workers=args.num_workers)        
    max_batches_per_iter_cnt = 2

    for epoch in range(50):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        val_loss = train_one_epoch(model, detection_loss, optimizer, lr_scheduler, max_batches_per_iter_cnt, dl, dl_val, epoch)
        pruner.export_model(model_path='{}/pruned_fots{}.pth'.format(output_path, epoch), mask_path='{}/mask_fots{}.pth'.format(output_path, epoch))
def main(cmdline_args):
    """Quantization workflow driver: evaluate, calibrate, finetune, export.

    Steps:
      1. Build the (quantizable) model plus train/test/ONNX data loaders.
      2. Evaluate the pretrained model for a baseline top-1.
      3. Calibrate quantizers on training batches, then re-evaluate.
      4. Optionally build a per-layer sensitivity profile.
      5. Finetune (QAT) for ``num_finetune_epochs`` epochs.
      6. Export to ONNX and, if requested, evaluate the exported model.

    Args:
        cmdline_args: Argument list fed to ``get_parser()``
            (e.g. ``sys.argv[1:]``).

    Returns:
        int: 0 on success; 1 when the ONNX export/evaluation fails or the
        finetuned accuracy misses the requested threshold.
    """
    parser = get_parser()
    args = parser.parse_args(cmdline_args)
    print(parser.description)
    print(args)

    # Seed both RNGs so calibration batches and finetuning are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    ## Prepare the pretrained model and data loaders
    model, data_loader_train, data_loader_test, data_loader_onnx = prepare_model(
        args.model_name, args.data_dir, not args.disable_pcq,
        args.batch_size_train, args.batch_size_test, args.batch_size_onnx,
        args.calibrator, args.pretrained, args.ckpt_path, args.ckpt_url)

    ## Initial accuracy evaluation.
    # Single stateless criterion reused for every stage below (fixes the
    # original's redundant second CrossEntropyLoss construction).
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        print('Initial evaluation:')
        top1_initial = evaluate(model,
                                criterion,
                                data_loader_test,
                                device="cuda",
                                print_freq=args.print_freq)

    ## Calibrate the model (collects activation ranges; no gradients needed)
    with torch.no_grad():
        calibrate_model(model=model,
                        model_name=args.model_name,
                        data_loader=data_loader_train,
                        num_calib_batch=args.num_calib_batch,
                        calibrator=args.calibrator,
                        hist_percentile=args.percentile,
                        out_dir=args.out_dir)

    ## Evaluate after calibration; -1.0 marks "stage skipped" in the summary.
    if args.num_calib_batch > 0:
        with torch.no_grad():
            print('Calibration evaluation:')
            top1_calibrated = evaluate(model,
                                       criterion,
                                       data_loader_test,
                                       device="cuda",
                                       print_freq=args.print_freq)
    else:
        top1_calibrated = -1.0

    ## Build sensitivity profile (per-layer quantization impact)
    if args.sensitivity:
        build_sensitivity_profile(model, criterion, data_loader_test)

    ## Finetune the model (quantization-aware training at a small LR)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.num_finetune_epochs)
    for epoch in range(args.num_finetune_epochs):
        # Training a single epoch
        train_one_epoch(model, criterion, optimizer, data_loader_train, "cuda",
                        0, 100)
        lr_scheduler.step()

    if args.num_finetune_epochs > 0:
        ## Evaluate after finetuning
        with torch.no_grad():
            print('Finetune evaluation:')
            top1_finetuned = evaluate(model,
                                      criterion,
                                      data_loader_test,
                                      device="cuda")
    else:
        top1_finetuned = -1.0

    ## Export to ONNX
    onnx_filename = args.out_dir + '/' + args.model_name + ".onnx"
    top1_onnx = -1.0
    if export_onnx(model, onnx_filename, args.batch_size_onnx,
                   not args.disable_pcq) and args.evaluate_onnx:
        ## Validate ONNX and evaluate
        top1_onnx = evaluate_onnx(onnx_filename, data_loader_onnx, criterion,
                                  args.print_freq)

    ## Print summary
    print("Accuracy summary:")
    table = PrettyTable(['Stage', 'Top1'])
    table.align['Stage'] = "l"
    table.add_row(['Initial', "{:.2f}".format(top1_initial)])
    table.add_row(['Calibrated', "{:.2f}".format(top1_calibrated)])
    table.add_row(['Finetuned', "{:.2f}".format(top1_finetuned)])
    table.add_row(['ONNX', "{:.2f}".format(top1_onnx)])
    print(table)

    ## Compare results: finetuned PyTorch accuracy must be within `threshold`
    ## of the ONNX accuracy. A negative threshold disables the gate.
    if args.threshold >= 0.0:
        if args.evaluate_onnx and top1_onnx < 0.0:
            print("Failed to export/evaluate ONNX!")
            return 1
        if args.num_finetune_epochs > 0:
            if top1_finetuned >= (top1_onnx - args.threshold):
                print("Accuracy threshold was met!")
            else:
                print("Accuracy threshold was missed!")
                return 1

    return 0
Beispiel #22
0
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Per-run tensorboard log directory: <LOGS_DIR>/<MODEL_NAME>, created on demand.
tensorboard_dir = os.path.join(hp['LOGS_DIR'], hp['MODEL_NAME'])
print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
if not os.path.exists(tensorboard_dir):
    os.makedirs(tensorboard_dir)
# presumably tensorboard_logger.configure — TODO confirm against imports
configure(tensorboard_dir)

# Best validation accuracy seen so far, plus early-stopping bookkeeping.
best_valid_acc, patience_counter = 0, 0

for epoch in range(0, hp['EPOCHS']):
    print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, hp['EPOCHS'],
                                               hp['LEARNING_RATE']))

    # train for 1 epoch
    train_loss, train_acc = train_one_epoch(model, optimizer, train_loader,
                                            epoch, hp)

    # evaluate on validation set
    valid_loss, valid_acc = validate(model, valid_loader, epoch, hp)

    # # reduce lr if validation loss plateaus
    # self.scheduler.step(valid_loss)

    # NOTE(review): within this span best_valid_acc is never updated and
    # patience_counter is only ever reset — confirm that the remainder of
    # the loop (past this chunk) updates both, otherwise every epoch after
    # the first improvement is flagged as best.
    is_best = valid_acc > best_valid_acc
    msg1 = "train loss: {:.3f} - train acc: {:.3f} "
    msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
    if is_best:
        patience_counter = 0
        msg2 += " [*]"  # marker appended when this epoch sets a new best
    msg = msg1 + msg2
    print(msg.format(train_loss, train_acc, valid_loss, valid_acc))
Beispiel #23
0
    model.load_state_dict(torch.load(Config.model_load))

# optimizer
# NoamOpt wraps Adam in the Transformer warmup/decay LR schedule; Adam's own
# lr is a placeholder that NoamOpt overrides — presumably, per the standard
# "Attention Is All You Need" recipe (betas=(0.9, 0.98), eps=1e-9).
optimizer = NoamOpt(
    Config.hidden_size, Config.factor, Config.warmup,
    Adam(model.parameters(), lr=Config.lr, betas=(0.9, 0.98), eps=1e-9))

# criterion
# Label smoothing of 0.1 over the target vocabulary; padding positions are
# excluded from the loss via ignore_index.
criterion = LabelSmoothingLoss(0.1,
                               tgt_vocab_size=output_lang.n_words,
                               ignore_index=Config.PADDING_token).cuda()
#criterion = nn.NLLLoss()

# make ckpts save
if not os.path.exists('ckpts'):
    os.makedirs('ckpts')

# training
best_bleu = -1  # sentinel: the first measured BLEU always triggers a save
for i in range(Config.n_epoch):
    train_one_epoch(train_loader, model, optimizer, criterion, print_every=50)
    # Print a few qualitative translations every 5 epochs (including epoch 0).
    if i % 5 == 0:
        evaluateRandomly(pairs_dev, input_lang, output_lang, model, n=3)
    acc, bleu = evaluate_dataset(dev_loader, model, output_lang)
    print("accuracy: {}  bleu score: {}".format(acc, bleu))

    # save model if best
    if best_bleu < bleu:
        best_bleu = bleu
        torch.save(model.state_dict(), "ckpts/transformer_{}.pt".format(i))
Beispiel #24
0
        checkpoint = torch.load(os.path.join(args.root, 'save/checkpoints',
                                             args.checkpoint),
                                map_location=device)
        global_model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_ep = checkpoint['epoch']

    print(args.epochs, ' epochs of training starts:')
    lines = ['Options:', str(args)]
    for epoch in tqdm(range(start_ep, args.epochs)):
        logger = train_one_epoch(global_model,
                                 criterion,
                                 optimizer,
                                 train_loader,
                                 lr_scheduler,
                                 device,
                                 epoch,
                                 print_freq=1000,
                                 args=args)
        lines.append(logger)
        lr_scheduler.step()
        if epoch % args.save_frequency == 0 or epoch == args.epochs - 1:
            torch.save(
                {
                    'model': global_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch + 1,
                    'model_name': args.model
                },
Beispiel #25
0
def controller_2d():  ## (cfg)
    """
    Controller function for training and testing VNet
    on Volumetric CT images of livers with and without lesions.

    Driven entirely by the module-level ``config`` dict: dataset paths,
    optional initial model state, mode ('test' computes metrics and exits),
    optimizer/loss options, and checkpointing settings. Logs to Weights &
    Biases in training mode.
    """
    print("Starting...")

    ## Load data
    full_dataset = LiTSDataset2d(config["dst_2d_path"],
                                 focus=config["focus"],
                                 data_limit=config["data_limit"])
    workers = config["num_workers"]

    ## Split data into train and test
    # NOTE(review): random_split is called without a seeded generator, so the
    # train/test split differs between runs — confirm this is intentional
    # (it matters when resuming from a saved state below).
    train_proportion = config["train_proportion"]
    len_train = int(len(full_dataset) * train_proportion)
    len_test = len(full_dataset) - len_train
    tr_set, te_set = random_split(full_dataset, (len_train, len_test))

    ## Init and load model if specified in config
    net = VNet2dAsDrawn(drop_rate=config["drop_rate"])
    # net = TestNet()
    # net = UNet(drop_rate=config["drop_rate"])
    # net = VNet2d(drop_rate=config["drop_rate"])
    # net = VGG()  ## Too big to run
    # net = DeepVNet2d(drop_rate=config["drop_rate"])

    ## Load model if specified in config
    print(config["init_2d_model_state"])
    if config["init_2d_model_state"] is not None:
        print("Attemt fetching of state dict at: ",
              config["init_2d_model_state"])
        # Checkpoint files store a dict (see torch.save below); only the
        # model weights are restored here — optimizer state is ignored.
        state_dict = torch.load(config["init_2d_model_state"])[
            "model_state_dict"]  # , map_location=torch.device("cpu")
        net.load_state_dict(state_dict)
        print("Successfully loaded net")
    net.to(device)

    ## Make summary
    # Prints a layer-by-layer summary for a single-channel 512x512 input.
    summary(net, (1, 512, 512))

    ## If only testing, run through net, get resulting metrics and store predictions.
    if config["mode"] == 'test':
        net.eval()

        # ## Small subset
        # n = 50
        # tr_set, _ = random_split(full_dataset, (n, len(full_dataset)-n))

        ## Training set
        tr_dataloader = DataLoader(tr_set,
                                   num_workers=workers,
                                   pin_memory=True)
        if not os.path.exists(config[f"dst2d_pred_{config['focus']}_path"]):
            os.mkdir(config[f"dst2d_pred_{config['focus']}_path"])
        # Epoch/epochlength args are dummies (1, 1) in test mode; predictions
        # are written to the focus-specific prediction path.
        train_info = test_one_epoch(
            net,
            tr_dataloader,
            device,
            1,
            1,
            wandblog=False,
            dst_path=config[f"dst2d_pred_{config['focus']}_path"])
        tr_df = pd.DataFrame(train_info, index=[0])
        tr_df.to_csv(
            os.path.join(
                config["dst_2d_path"],
                "tr_metrics_{}_{}_run{:02}.csv".format(config["label_type"],
                                                       config["focus"],
                                                       config["runid"])))

        ## Test set
        # te_dataloader = DataLoader(te_set, num_workers=workers, pin_memory=True)
        # test_info = test_one_epoch(net, te_dataloader, device,
        #                            1, 1, wandblog=False, dst_path=config[f"dst2d_pred_{config['focus']}_path"])
        # te_df = pd.DataFrame(test_info, index=[0])
        # te_df.to_csv(os.path.join(config["dst_2d_path"],
        #                           "te_metrics_{}_{}_run{:02}.csv".format(config["label_type"], config["focus"], config["runid"])))
        # Test mode ends here; training code below never runs.
        exit()

    ## Monitor process with weights and biases
    wandb.init(config=config)

    ## Optimizer
    optimizer = torch.optim.Adam(params=net.parameters(),
                                 **config["optim_opts"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       config["lr_decay_rate"],
                                                       last_epoch=-1)

    ## Loss
    # NOTE(review): no else branch — an unrecognized label_type leaves
    # `critic` undefined and raises NameError at train time.
    if config["label_type"] == 'segmentation':
        critic = DiceLoss(**config["loss_opts"])
        # critic = TverskyLoss(**config["loss_opts"])
    elif config["label_type"] == 'pixelcount':
        critic = MSEPixelCountLoss(**config["loss_opts"])
    elif config["label_type"] == 'binary':
        critic = torch.nn.BCELoss()
        # critic = WeightedBCELoss(weights=[0.8, 0.2])

    ## Training loop
    for epoch in range(config["max_epochs"]):
        print(f"Epoch: {epoch}.")

        ## Make dataloaders
        # Rebuilt every epoch (reshuffles both splits each time).
        print("Loading data ...")
        tr_dataloader = DataLoader(tr_set,
                                   batch_size=config["batch_size"],
                                   shuffle=True,
                                   num_workers=workers,
                                   pin_memory=True)
        te_dataloader = DataLoader(te_set,
                                   batch_size=config["batch_size"],
                                   shuffle=True,
                                   num_workers=workers,
                                   pin_memory=True)

        # Total batches per epoch (train + test), passed to the epoch helpers
        # — presumably for progress reporting; TODO confirm.
        epochlength = len(tr_dataloader) + len(te_dataloader)

        print("Training ...")
        train_info = train_one_epoch(net, optimizer, critic, tr_dataloader,
                                     device, epoch, epochlength)

        print("Testing ...")
        test_info = test_one_epoch(net, te_dataloader, device, epoch,
                                   epochlength)

        scheduler.step()

        ## Checkpoint storage (with prediction exmple)
        # Extracts the class name from repr(type(net)), e.g.
        # "<class 'models.VNet2dAsDrawn'>" -> "VNet2dAsDrawn".
        # NOTE(review): fragile — breaks if the module path has !=2 dots.
        netname = str(type(net)).strip("'>").split(".")[1]
        saved_states_folder = os.path.join(
            "datasets/saved_states/runid_{:03}/".format(config["runid"]))
        # NOTE(review): checkpoint_interval == 1 makes the modulus zero and
        # raises ZeroDivisionError; the -1 also shifts which epochs save
        # (epoch 0 always saves) — confirm intended cadence.
        if epoch % (config["checkpoint_interval"] - 1) == 0:  # and epoch != 0:
            if not os.path.exists(saved_states_folder):
                os.mkdir(saved_states_folder)
            state_name = "{}_runid_{:02}_epoch{:02}.pth".format(
                netname, config["runid"], epoch)
            state_dict_path = os.path.join(saved_states_folder, state_name)

            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "loss": np.mean(train_info["loss"]),  ## mean loss of epoch
                    "config": config,
                },
                state_dict_path)

            ## Store example prediction
            # One random test image, run through the net with wandb logging
            # and a PNG written to the figure path.
            imageidx = torch.randint(0, len(te_set), (1, ))
            ex_loader = DataLoader(torch.utils.data.Subset(te_set, imageidx),
                                   num_workers=workers,
                                   pin_memory=True)
            test_one_epoch(net,
                           ex_loader,
                           device,
                           epoch,
                           epochlength,
                           wandblog=True,
                           dst_format='png',
                           dst_path=config["dst2d_fig_path"])
            print("done!")
        print()