            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            smooth_crossentropy(model(inputs), targets).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

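        # end-of-epoch evaluation on the held-out test set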
        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())
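        # save a checkpoint every 50 epochs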
        if epoch % 50 == 0:
            PATH = './trained_models/sam_net_' + str(epoch).zfill(3) + '.pth'
            torch.save(model.state_dict(), PATH)

    PATH = './trained_models/sam_net_250_final.pth'
    torch.save(model.state_dict(), PATH)

    log.flush()
Example #2
File: train.py  Project: brkmnd/DcrParser
def main_worker(gpu, n_gpus_per_node, args):
    is_master = gpu == 0
    directory = initialize(args,
                           create_directory=is_master,
                           init_wandb=args.log_wandb and is_master)

    os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12345"

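    # initialize the distributed process group (one process per GPU on this node)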
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method="env://",
                                world_size=n_gpus_per_node,
                                rank=gpu)

    dataset = SharedDataset(args)
    dataset.load_datasets(args, gpu, n_gpus_per_node)

    model = Model(dataset, args)
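    # build two parameter groups so encoder and decoder get separate weight decay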
    parameters = [
        {"params": p, "weight_decay": args.encoder_weight_decay}
        for p in model.get_encoder_parameters(args.n_encoder_layers)
    ] + [
        {"params": model.get_decoder_parameters(), "weight_decay": args.decoder_weight_decay}
    ]
    optimizer = AdamW(parameters, betas=(0.9, args.beta_2))
    scheduler = multi_scheduler_wrapper(optimizer, args)
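    # gradient clipping (AutoClip) over every parameter except the per-head loss weights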
    autoclip = AutoClip([
        p for name, p in model.named_parameters() if "loss_weights" not in name
    ])
    if args.balance_loss_weights:
        loss_weight_learner = LossWeightLearner(args, model, n_gpus_per_node)

    if is_master:
        if args.log_wandb:
            import wandb
            wandb.watch(model, log=args.wandb_log_mode)
        print(f"\nmodel: {model}\n")
        log = Log(dataset,
                  model,
                  optimizer,
                  args,
                  directory,
                  log_each=10,
                  log_wandb=args.log_wandb)

    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)

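    # wrap the model for distributed training; raw_model keeps a handle to the unwrapped module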
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[gpu])
        raw_model = model.module
    else:
        raw_model = model

    force_cpu_dev = False  # changed: optionally force CPU execution (used by the force_cpu_dev checks below)
    if force_cpu_dev:
        dev0 = torch.device("cpu")
        model.to(dev0)
        gpu = dev0  # reuse the gpu variable as the target device for later transfers

    for epoch in range(args.epochs):

        #
        # TRAINING
        #

        model.train()
        if is_master:
            log.train(len_dataset=dataset.train_size)

        i = 0
        model.zero_grad()
        losses_over_bs = []  # changed: accumulate per-batch losses for the epoch summary
        for batch in dataset.train:
            if not force_cpu_dev:  # changed: move the batch to the GPU unless CPU execution is forced
                batch = Batch.to(batch, gpu)
            total_loss, losses, stats = model(batch)

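            # merge each head's current loss weights into the logged statistics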
            for head in raw_model.heads:
                stats.update(head.loss_weights_dict())

            if args.balance_loss_weights:
                loss_weight_learner.compute_grad(losses, epoch)

            losses_over_bs.append(total_loss.item())  # changed: record the batch loss for later analysis
            total_loss.backward()

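            # gradient accumulation: only update the weights every accumulation_steps batches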
            if (i + 1) % args.accumulation_steps == 0:
                grad_norm = autoclip()

                if args.balance_loss_weights:
                    loss_weight_learner.step(epoch)
                scheduler(epoch)
                optimizer.step()
                model.zero_grad()

                if is_master:
                    with torch.no_grad():
                        batch_size = batch["every_input"][0].size(0) * args.accumulation_steps
                        # loss_weight_learner only exists when args.balance_loss_weights is set
                        learning_rates = scheduler.lr()
                        if args.balance_loss_weights:
                            learning_rates = learning_rates + [loss_weight_learner.scheduler.lr()]
                        log(batch_size,
                            stats,
                            args.frameworks,
                            grad_norm=grad_norm,
                            learning_rates=learning_rates)

            del total_loss, losses

            i += 1

        if not is_master:
            continue

        #
        # VALIDATION CROSS-ENTROPIES
        #
        model.eval()
        log.eval(len_dataset=dataset.val_size)

        with torch.no_grad():
            for batch in dataset.val:
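                # skip validation batches that run out of GPU memory instead of aborting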
                try:
                    _, _, stats = model(Batch.to(batch, gpu))

                    batch_size = batch["every_input"][0].size(0)
                    log(batch_size, stats, args.frameworks)
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print('| WARNING: ran out of memory, skipping batch')
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise e

        lobs = np.array(losses_over_bs)  # changed: per-batch losses collected during training
        # changed: print mean; max; min training loss for this epoch
        print(f"{lobs.mean()}; {lobs.max()}; {lobs.min()}")
        log.flush()

        #
        # VALIDATION MRP-SCORES
        #
        predict(raw_model,
                dataset.val,
                args.validation_data,
                args,
                directory,
                gpu,
                run_evaluation=True,
                epoch=epoch)

    #
    # TEST PREDICTION
    #
    test_fpath = f"{directory}/test_predictions/"  # changed: create the directory only if it does not already exist
    if not os.path.exists(test_fpath):
        os.mkdir(test_fpath)

    predict(raw_model, dataset.test, args.test_data, args,
            f"{directory}/test_predictions/", gpu)
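main_worker(gpu, n_gpus_per_node, args) has the signature that torch.multiprocessing.spawn expects, so it is presumably launched once per GPU. A minimal launch sketch under that assumption; the launch wrapper and its name are illustrative and not part of DcrParser:

# minimal launch sketch, assuming main_worker is spawned one process per GPU
import torch
import torch.multiprocessing as mp

def launch(args):
    n_gpus_per_node = torch.cuda.device_count()
    if args.distributed and n_gpus_per_node > 1:
        # spawn calls main_worker(gpu, n_gpus_per_node, args) once per process,
        # with gpu ranging over 0..n_gpus_per_node-1
        mp.spawn(main_worker, args=(n_gpus_per_node, args), nprocs=n_gpus_per_node)
    else:
        main_worker(0, 1, args)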