Example #1
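Loads a saved training checkpoint, restores the hyperparameters stored in it, rebuilds the dataset and model, and runs prediction on the validation data.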
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--data_directory",
                        type=str,
                        default="/home/samueld/mrp_update/mrp")
    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint)
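    # rebuild the run's hyperparameters from the checkpoint and point them at the local data directory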
    args = Params().load_state_dict(checkpoint["args"]).init_data_paths(
        args.data_directory)
    args.log_wandb = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    directory = initialize(args,
                           create_directory=True,
                           init_wandb=False,
                           directory_prefix="inference_")

    dataset = SharedDataset(args)
    dataset.load_datasets(args, 0, 1)

    model = Model(dataset, args).to(device)
    model.load_state_dict(checkpoint["model"])

    print("inference of validation data", flush=True)
    predict(model,
            dataset.val,
            args.validation_data,
            args,
            directory,
            0,
Example #2
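Sets up WideResNet training on CIFAR-10, wrapping plain SGD in the SAM optimizer.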
                        help="Number of CPU threads for dataloaders.")
    parser.add_argument("--rho",
                        default=0.05,
                        type=float,
                        help="Rho parameter for SAM.")
    parser.add_argument("--weight_decay",
                        default=0.0005,
                        type=float,
                        help="L2 weight decay.")
    parser.add_argument("--width_factor",
                        default=8,
                        type=int,
                        help="How many times wider compared to normal ResNet.")
    args = parser.parse_args()

    initialize(args, seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataset = Cifar(args.batch_size, args.threads)
    log = Log(log_each=10)
    model = WideResNet(args.depth,
                       args.width_factor,
                       args.dropout,
                       in_channels=3,
                       labels=10).to(device)

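    # SAM wraps a base optimizer class (SGD here); the remaining keyword arguments are forwarded to it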
    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    rho=args.rho,
                    lr=args.learning_rate,
Example #3
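A per-GPU worker for (optionally distributed) training: it initializes the process group, builds separate weight-decay groups for encoder and decoder parameters, trains with gradient accumulation and adaptive gradient clipping, evaluates after every epoch, and finally writes test predictions. Comments marked "# changed" flag local modifications to the original script.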
def main_worker(gpu, n_gpus_per_node, args):
    is_master = gpu == 0
    directory = initialize(args,
                           create_directory=is_master,
                           init_wandb=args.log_wandb and is_master)

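    # rendezvous settings for torch.distributed; an externally provided MASTER_PORT takes precedence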
    os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12345"

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method="env://",
                                world_size=n_gpus_per_node,
                                rank=gpu)

    dataset = SharedDataset(args)
    dataset.load_datasets(args, gpu, n_gpus_per_node)

    model = Model(dataset, args)
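    # encoder and decoder parameters get their own weight-decay settings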
    parameters = [
        {"params": p, "weight_decay": args.encoder_weight_decay}
        for p in model.get_encoder_parameters(args.n_encoder_layers)
    ] + [{
        "params": model.get_decoder_parameters(),
        "weight_decay": args.decoder_weight_decay,
    }]
    optimizer = AdamW(parameters, betas=(0.9, args.beta_2))
    scheduler = multi_scheduler_wrapper(optimizer, args)
    autoclip = AutoClip([
        p for name, p in model.named_parameters() if "loss_weights" not in name
    ])
    if args.balance_loss_weights:
        loss_weight_learner = LossWeightLearner(args, model, n_gpus_per_node)

    if is_master:
        if args.log_wandb:
            import wandb
            wandb.watch(model, log=args.wandb_log_mode)
        print(f"\nmodel: {model}\n")
        log = Log(dataset,
                  model,
                  optimizer,
                  args,
                  directory,
                  log_each=10,
                  log_wandb=args.log_wandb)

    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[gpu])
        raw_model = model.module
    else:
        raw_model = model

    force_cpu_dev = False  # changed: optionally force CPU execution (used with the guarded transfer below)
    if force_cpu_dev:
        dev0 = torch.device("cpu")
        model.to(dev0)
        gpu = dev0

    for epoch in range(args.epochs):

        #
        # TRAINING
        #

        model.train()
        if is_master:
            log.train(len_dataset=dataset.train_size)

        i = 0
        model.zero_grad()
        losses_over_bs = []  # changed: accumulate per-batch losses for the epoch summary
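        # gradient-accumulation loop: the optimizer only steps every accumulation_steps batches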
        for batch in dataset.train:
            if not force_cpu_dev:  # changed: skip the device transfer when forcing CPU
                batch = Batch.to(batch, gpu)
            total_loss, losses, stats = model(batch)

            for head in raw_model.heads:
                stats.update(head.loss_weights_dict())

            if args.balance_loss_weights:
                loss_weight_learner.compute_grad(losses, epoch)

            losses_over_bs.append(
                total_loss.item())  # changed: record the batch loss for later analysis
            total_loss.backward()

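            # enough gradients accumulated: clip, step the schedulers and optimizer, then reset the gradients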
            if (i + 1) % args.accumulation_steps == 0:
                grad_norm = autoclip()

                if args.balance_loss_weights:
                    loss_weight_learner.step(epoch)
                scheduler(epoch)
                optimizer.step()
                model.zero_grad()

                if is_master:
                    with torch.no_grad():
                        batch_size = batch["every_input"][0].size(
                            0) * args.accumulation_steps
                        learning_rates = scheduler.lr()
                        if args.balance_loss_weights:
                            learning_rates = learning_rates + [
                                loss_weight_learner.scheduler.lr()
                            ]
                        log(batch_size,
                            stats,
                            args.frameworks,
                            grad_norm=grad_norm,
                            learning_rates=learning_rates)

            del total_loss, losses

            i += 1

        if not is_master:
            continue

        #
        # VALIDATION CROSS-ENTROPIES
        #
        model.eval()
        log.eval(len_dataset=dataset.val_size)

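        # validation cross-entropy pass; out-of-memory batches are skipped rather than aborting the epoch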
        with torch.no_grad():
            for batch in dataset.val:
                try:
                    _, _, stats = model(Batch.to(batch, gpu))

                    batch_size = batch["every_input"][0].size(0)
                    log(batch_size, stats, args.frameworks)
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print('| WARNING: ran out of memory, skipping batch')
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise e

        lobs = np.array(losses_over_bs)  # changed: summarize per-batch losses
        # changed: report the mean, max, and min loss for the epoch
        print(f"{lobs.mean()}; {lobs.max()}; {lobs.min()}")
        log.flush()

        #
        # VALIDATION MRP-SCORES
        #
        predict(raw_model,
                dataset.val,
                args.validation_data,
                args,
                directory,
                gpu,
                run_evaluation=True,
                epoch=epoch)

    #
    # TEST PREDICTION
    #
    test_fpath = f"{directory}/test_predictions/"  # changed: avoid failing when the directory already exists
    if not os.path.exists(test_fpath):
        os.mkdir(test_fpath)

    predict(raw_model, dataset.test, args.test_data, args,
            f"{directory}/test_predictions/", gpu)