Example #1
def main(args):
    # gpu or cpu
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    args = utils.setup_experiment(args)
    utils.init_logging(args)

    # Loading models
    MODEL_PATH_LOAD = "../lidar_experiments/2d/lidar_unet2d/lidar-unet2d-Nov-08-16:29:49/checkpoints/checkpoint_best.pt"

    train_new_model = True

    # Build data loaders, a model and an optimizer
    if train_new_model:
        model = models.build_model(args).to(device)
    else:
        model = models.build_model(args)
        model.load_state_dict(torch.load(MODEL_PATH_LOAD)['model'][0])
        model.to(device)

    print(model)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[5, 15, 30, 50, 100, 250], gamma=0.5)
    logging.info(
        f"Built a model consisting of {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    if args.resume_training:
        state_dict = utils.load_checkpoint(args, model, optimizer, scheduler)
        global_step = state_dict['last_step']
        start_epoch = int(state_dict['last_step'] /
                          (403200 / state_dict['args'].batch_size)) + 1
    else:
        global_step = -1
        start_epoch = 0

    ## Load the pts files
    # Loads as a list of numpy arrays
    scan_line_tensor = torch.load(args.data_path + 'scan_line_tensor.pts')
    train_idx_list = torch.load(args.data_path + 'train_idx_list.pts')
    valid_idx_list = torch.load(args.data_path + 'valid_idx_list.pts')
    sc = torch.load(args.data_path + 'sc.pts')

    # Dataloaders
    train_dataset = LidarLstmDataset(scan_line_tensor, train_idx_list,
                                     args.seq_len, args.mask_pts_per_seq)
    valid_dataset = LidarLstmDataset(scan_line_tensor, valid_idx_list,
                                     args.seq_len, args.mask_pts_per_seq)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=4,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=4,
                                               shuffle=True)

    # Track moving average of loss values
    train_meters = {
        name: utils.RunningAverageMeter(0.98)
        for name in (["train_loss"])
    }
    valid_meters = {name: utils.AverageMeter() for name in (["valid_loss"])}
    writer = SummaryWriter(
        log_dir=args.experiment_dir) if not args.no_visual else None

    ##################################################
    # TRAINING
    for epoch in range(start_epoch, args.num_epochs):
        if args.resume_training:
            if epoch % 1 == 0:
                optimizer.param_groups[0]["lr"] /= 2
                print('learning rate reduced by factor of 2')

        train_bar = utils.ProgressBar(train_loader, epoch)
        for meter in train_meters.values():
            meter.reset()

    #     epoch_loss_sum = 0
        for batch_id, (clean, mask) in enumerate(train_bar):
            # dataloader returns [clean, mask] list
            model.train()
            global_step += 1
            inputs = clean.to(device)
            mask_inputs = mask.to(device)
            # only use the mask part of the outputs
            raw_outputs = model(inputs, mask_inputs)
            outputs = ((1 - mask_inputs[:, :3, :, :]) * raw_outputs
                       + mask_inputs[:, :3, :, :] * inputs[:, :3, :, :])

            if args.wtd_loss:
                loss = weighted_MSELoss(outputs, inputs[:, :3, :, :],
                                        sc) / (inputs.size(0) *
                                               (args.mask_pts_per_seq**2))
                # Regularization?

            else:
                # normalized by the number of masked points
                loss = F.mse_loss(outputs, inputs[:,:3,:,:], reduction="sum") / \
                       (inputs.size(0) * (args.mask_pts_per_seq**2))

            model.zero_grad()
            loss.backward()
            optimizer.step()
            #         epoch_loss_sum += loss * inputs.size(0)
            train_meters["train_loss"].update(loss)
            train_bar.log(dict(**train_meters,
                               lr=optimizer.param_groups[0]["lr"]),
                          verbose=True)

            if writer is not None and global_step % args.log_interval == 0:
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"],
                                  global_step)
                writer.add_scalar("loss/train", loss.item(), global_step)
                gradients = torch.cat([
                    p.grad.view(-1)
                    for p in model.parameters() if p.grad is not None
                ],
                                      dim=0)
                writer.add_histogram("gradients", gradients, global_step)
                sys.stdout.flush()
    #     epoch_loss = epoch_loss_sum / len(train_loader.dataset)

        if epoch % args.valid_interval == 0:
            model.eval()
            for meter in valid_meters.values():
                meter.reset()

            valid_bar = utils.ProgressBar(valid_loader)
            val_loss = 0
            for sample_id, (clean, mask) in enumerate(valid_bar):
                with torch.no_grad():
                    inputs = clean.to(device)
                    mask_inputs = mask.to(device)
                    # only use the mask part of the outputs
                    raw_output = model(inputs, mask_inputs)
                    output = ((1 - mask_inputs[:, :3, :, :]) * raw_output
                              + mask_inputs[:, :3, :, :] * inputs[:, :3, :, :])

                    # TODO: only compute the loss on the masked part of the output

                    if args.wtd_loss:
                        val_loss = weighted_MSELoss(
                            output, inputs[:, :3, :, :],
                            sc) / (inputs.size(0) * (args.mask_pts_per_seq**2))
                    else:
                        # normalized by the number of masked points
                        val_loss = F.mse_loss(output, inputs[:, :3, :, :], reduction="sum") / \
                                   (inputs.size(0) * (args.mask_pts_per_seq**2))

                    valid_meters["valid_loss"].update(val_loss.item())

            if writer is not None:
                writer.add_scalar("loss/valid", valid_meters['valid_loss'].avg,
                                  global_step)
                sys.stdout.flush()

            logging.info(
                train_bar.print(
                    dict(**train_meters,
                         **valid_meters,
                         lr=optimizer.param_groups[0]["lr"])))
            utils.save_checkpoint(args,
                                  global_step,
                                  model,
                                  optimizer,
                                  score=valid_meters["valid_loss"].avg,
                                  mode="min")
        scheduler.step()

    logging.info(
        f"Done training! Best Loss {utils.save_checkpoint.best_score:.3f} obtained after step {utils.save_checkpoint.best_step}."
    )
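
Example #1 calls a weighted_MSELoss helper and loads an sc object that are not shown in this excerpt. Below is a minimal sketch consistent with how they are used above, assuming sc holds one weight per output channel; both the body and that assumption are hypothetical.

import torch

def weighted_MSELoss(outputs, targets, sc):
    # Per-channel weighted sum of squared errors; the caller divides by the
    # number of masked points, so no mean reduction is taken here.
    weights = torch.as_tensor(sc, dtype=outputs.dtype,
                              device=outputs.device).view(1, -1, 1, 1)
    return (weights * (outputs - targets) ** 2).sum()
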
Example #2
if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 1, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)

    #logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()

    # Get truth and reco data

    xTrain, xTest, uniformityWeightsTrain, uniformityWeightsTest, yTrain, yTest, inputScaler, outputScaler = prepMatchedData(
        args.h5Path,
        args.nTrainSamp,
        args.nTestSamp,
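
Several of these examples track statistics with utils.RunningAverageMeter(momentum). The utility itself is not shown; a common implementation of such an exponentially weighted meter is sketched below (assumed, not taken verbatim from any of these repositories).

class RunningAverageMeter:
    """Tracks the latest value and an exponential moving average of it."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0.0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val
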
Example #3
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError("Training on CPU is not supported.")
    utils.setup_experiment(args)
    utils.init_logging(args)

    train_loaders, valid_loaders = data.build_dataset(
        args.dataset, args.data_path, batch_size=args.batch_size)
    model = models.build_model(args).cuda()
    optimizer = optim.build_optimizer(args, model.parameters())
    logging.info(
        f"Built a model consisting of {sum(p.numel() for p in model.parameters() if p.requires_grad):,} parameters"
    )

    meters = {
        name: utils.RunningAverageMeter(0.98)
        for name in (["loss", "context", "graph", "target"])
    }
    acc_names = ["overall"
                 ] + [f"task{idx}" for idx in range(len(valid_loaders))]
    acc_meters = {name: utils.AverageMeter() for name in acc_names}
    writer = SummaryWriter(
        log_dir=args.experiment_dir) if not args.no_visual else None

    global_step = -1
    for epoch in range(args.num_epochs):
        acc_tasks = {f"task{idx}": None for idx in range(len(valid_loaders))}
        for task_id, train_loader in enumerate(train_loaders):
            for repeat in range(args.num_repeats_per_task):
                train_bar = utils.ProgressBar(train_loader,
                                              epoch,
                                              prefix=f"task {task_id}")
                for meter in meters.values():
                    meter.reset()

                for batch_id, (images, labels) in enumerate(train_bar):
                    model.train()
                    global_step += 1
                    images, labels = images.cuda(), labels.cuda()
                    outputs = model(images, labels, task_id=task_id)

                    if global_step == 0:
                        continue
                    loss = outputs["loss"]
                    model.zero_grad()
                    loss.backward()
                    optimizer.step()

                    meters["loss"].update(loss.item())
                    meters["context"].update(outputs["context_loss"].item())
                    meters["target"].update(outputs["target_loss"].item())
                    meters["graph"].update(outputs["graph_loss"].item())
                    train_bar.log(dict(
                        **meters,
                        lr=optimizer.get_lr(),
                    ))

                if writer is not None:
                    writer.add_scalar("loss/train", loss.item(), global_step)
                    gradients = torch.cat([
                        p.grad.view(-1)
                        for p in model.parameters() if p.grad is not None
                    ],
                                          dim=0)
                    writer.add_histogram("gradients", gradients, global_step)

            model.eval()
            for meter in acc_meters.values():
                meter.reset()
            for idx, valid_loader in enumerate(valid_loaders):
                valid_bar = utils.ProgressBar(valid_loader,
                                              epoch,
                                              prefix=f"task {task_id}")
                for batch_id, (images, labels) in enumerate(valid_bar):
                    model.eval()
                    with torch.no_grad():
                        images, labels = images.cuda(), labels.cuda()
                        outputs = model.predict(images, labels, task_id=idx)
                        correct = outputs["preds"].eq(labels).sum().item()
                        acc_meters[f"task{idx}"].update(100 * correct,
                                                        n=len(images))
                acc_meters["overall"].update(acc_meters[f"task{idx}"].avg)

            acc_tasks[f"task{task_id}"] = acc_meters[f"task{task_id}"].avg
            if writer is not None:
                for name, meter in acc_meters.items():
                    writer.add_scalar(f"accuracy/{name}", meter.avg,
                                      global_step)
            logging.info(
                train_bar.print(
                    dict(**meters, **acc_meters, lr=optimizer.get_lr())))
            utils.save_checkpoint(args,
                                  global_step,
                                  model,
                                  optimizer,
                                  score=acc_meters["overall"].avg,
                                  mode="max")

    bwt = sum(acc_meters[task].avg - acc
              for task, acc in acc_tasks.items()) / (len(valid_loaders) - 1)
    logging.info(
        f"Done training! Final accuracy {acc_meters['overall'].avg:.4f}, backward transfer {bwt:.4f}."
    )
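
The per-task accuracies in Example #3 are accumulated with utils.AverageMeter, a plain arithmetic-mean counterpart to the running meter. A minimal sketch matching the update(value, n=...) call signature used above; the actual utility may differ.

class AverageMeter:
    """Tracks a running arithmetic mean of the values passed to update()."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
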
Example #4
def main(args):
	device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
	utils.setup_experiment(args)
	utils.init_logging(args)

	# Build data loaders, a model and an optimizer
	model = models.build_model(args).to(device)
	print(model)
	optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80, 90, 100], gamma=0.5)
	logging.info(f"Built a model consisting of {sum(p.numel() for p in model.parameters()):,} parameters")
	
	if args.resume_training:
		state_dict = utils.load_checkpoint(args, model, optimizer, scheduler)
		global_step = state_dict['last_step']
		start_epoch = int(state_dict['last_step']/(403200/state_dict['args'].batch_size))+1
	else:
		global_step = -1
		start_epoch = 0
		
	train_loader, valid_loader, _ = data.build_dataset(args.dataset, args.data_path, batch_size=args.batch_size)
	
	# Track moving average of loss values
	train_meters = {name: utils.RunningAverageMeter(0.98) for name in (["train_loss", "train_psnr", "train_ssim"])}
	valid_meters = {name: utils.AverageMeter() for name in (["valid_psnr", "valid_ssim"])}
	writer = SummaryWriter(log_dir=args.experiment_dir) if not args.no_visual else None

	for epoch in range(start_epoch, args.num_epochs):
		if args.resume_training:
			if epoch %10 == 0:
				optimizer.param_groups[0]["lr"] /= 2
				print('learning rate reduced by factor of 2')
				
		train_bar = utils.ProgressBar(train_loader, epoch)
		for meter in train_meters.values():
			meter.reset()

		for batch_id, inputs in enumerate(train_bar):
			model.train()

			global_step += 1
			inputs = inputs.to(device)
			noise = utils.get_noise(inputs, mode = args.noise_mode, 
												min_noise = args.min_noise/255., max_noise = args.max_noise/255.,
												noise_std = args.noise_std/255.)

			noisy_inputs = noise + inputs;
			outputs = model(noisy_inputs)
			loss = F.mse_loss(outputs, inputs, reduction="sum") / (inputs.size(0) * 2)

			model.zero_grad()
			loss.backward()
			optimizer.step()

			train_psnr = utils.psnr(outputs, inputs)
			train_ssim = utils.ssim(outputs, inputs)
			train_meters["train_loss"].update(loss.item())
			train_meters["train_psnr"].update(train_psnr.item())
			train_meters["train_ssim"].update(train_ssim.item())
			train_bar.log(dict(**train_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

			if writer is not None and global_step % args.log_interval == 0:
				writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)
				writer.add_scalar("loss/train", loss.item(), global_step)
				writer.add_scalar("psnr/train", train_psnr.item(), global_step)
				writer.add_scalar("ssim/train", train_ssim.item(), global_step)
				gradients = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None], dim=0)
				writer.add_histogram("gradients", gradients, global_step)
				sys.stdout.flush()

		if epoch % args.valid_interval == 0:
			model.eval()
			for meter in valid_meters.values():
				meter.reset()

			valid_bar = utils.ProgressBar(valid_loader)
			for sample_id, sample in enumerate(valid_bar):
				with torch.no_grad():
					sample = sample.to(device)
					noise = utils.get_noise(sample, mode = 'S', 
												noise_std = (args.min_noise +  args.max_noise)/(2*255.))

					noisy_inputs = noise + sample;
					output = model(noisy_inputs)
					valid_psnr = utils.psnr(output, sample)
					valid_meters["valid_psnr"].update(valid_psnr.item())
					valid_ssim = utils.ssim(output, sample)
					valid_meters["valid_ssim"].update(valid_ssim.item())

					if writer is not None and sample_id < 10:
						image = torch.cat([sample, noisy_inputs, output], dim=0)
						image = torchvision.utils.make_grid(image.clamp(0, 1), nrow=3, normalize=False)
						writer.add_image(f"valid_samples/{sample_id}", image, global_step)

			if writer is not None:
				writer.add_scalar("psnr/valid", valid_meters['valid_psnr'].avg, global_step)
				writer.add_scalar("ssim/valid", valid_meters['valid_ssim'].avg, global_step)
				sys.stdout.flush()

			logging.info(train_bar.print(dict(**train_meters, **valid_meters, lr=optimizer.param_groups[0]["lr"])))
			utils.save_checkpoint(args, global_step, model, optimizer, score=valid_meters["valid_psnr"].avg, mode="max")
		scheduler.step()

	logging.info(f"Done training! Best PSNR {utils.save_checkpoint.best_score:.3f} obtained after step {utils.save_checkpoint.best_step}.")
Example #5
    else:
        kernel_optimizer = optim.Adam(list(kernel_net.parameters()) +
                                      [log_bandwidth],
                                      lr=args.lr,
                                      betas=(.5, .9),
                                      weight_decay=args.critic_weight_decay)

    if args.kernel == "neural":
        if args.k_dim == 1:
            encoder_fn = lambda x: kernel_net(x)[:, None]
        else:
            encoder_fn = kernel_net
    else:
        encoder_fn = lambda x: x

    time_meter = utils.RunningAverageMeter(0.98)
    loss_meter = utils.RunningAverageMeter(0.98)
    ebm_meter = utils.RunningAverageMeter(0.98)

    def sample_data():
        if args.fixed_dataset:
            # np.random.permutation returns an ndarray, which torch.from_numpy accepts
            inds = torch.from_numpy(np.random.permutation(args.batch_size))
            return fixed_data[inds]
        else:
            return trueICA.sample(args.batch_size)

    best_loss = float('inf')
    modelICA.train()
    end = time.time()
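
Example #5 jointly optimizes kernel_net and a log_bandwidth parameter, which suggests a kernel whose bandwidth is learned in log-space so it stays positive under gradient updates. A hypothetical sketch of such a kernel applied to encoded features; the kernel actually used by the script is not shown.

import torch

def rbf_kernel(x, y, log_bandwidth):
    # Gaussian (RBF) kernel exp(-||x - y||^2 / (2 h^2)) with h = exp(log_bandwidth).
    h = torch.exp(log_bandwidth)
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return torch.exp(-sq_dists / (2.0 * h ** 2))
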
Example #6
def _main(rank, world_size, args, savepath, logger):

    if rank == 0:
        logger.info(args)
        logger.info(f"Saving to {savepath}")
        tb_writer = SummaryWriter(os.path.join(savepath, "tb_logdir"))

    device = torch.device(
        f'cuda:{rank:d}' if torch.cuda.is_available() else 'cpu')

    if rank == 0:
        if device.type == 'cuda':
            logger.info('Found {} CUDA devices.'.format(
                torch.cuda.device_count()))
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                logger.info('{} \t Memory: {:.2f}GB'.format(
                    props.name, props.total_memory / (1024**3)))
        else:
            logger.info('WARNING: Using device {}'.format(device))

    t0, t1 = map(lambda x: cast(x, device), get_t0_t1(args.data))

    train_set = load_data(args.data, split="train")
    val_set = load_data(args.data, split="val")
    test_set = load_data(args.data, split="test")

    train_epoch_iter = EpochBatchIterator(
        dataset=train_set,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
        batch_sampler=train_set.batch_by_size(args.max_events),
        seed=args.seed + rank,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )

    if rank == 0:
        logger.info(
            f"{len(train_set)} training examples, {len(val_set)} val examples, {len(test_set)} test examples"
        )

    x_dim = get_dim(args.data)

    if args.model == "jumpcnf" and args.tpp == "neural":
        model = JumpCNFSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            solve_reverse=args.solve_reverse,
            tol=args.tol,
            otreg_strength=args.otreg_strength,
            tpp_otreg_strength=args.tpp_otreg_strength,
            layer_type=args.layer_type,
        ).to(device)
    elif args.model == "attncnf" and args.tpp == "neural":
        model = SelfAttentiveCNFSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            solve_reverse=args.solve_reverse,
            l2_attn=args.l2_attn,
            tol=args.tol,
            otreg_strength=args.otreg_strength,
            tpp_otreg_strength=args.tpp_otreg_strength,
            layer_type=args.layer_type,
            lowvar_trace=not args.naive_hutch,
        ).to(device)
    elif args.model == "cond_gmm" and args.tpp == "neural":
        model = JumpGMMSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            tol=args.tol,
            tpp_otreg_strength=args.tpp_otreg_strength,
        ).to(device)
    else:
        # Mix and match between spatial and temporal models.
        if args.tpp == "poisson":
            tpp_model = HomogeneousPoissonPointProcess()
        elif args.tpp == "hawkes":
            tpp_model = HawkesPointProcess()
        elif args.tpp == "correcting":
            tpp_model = SelfCorrectingPointProcess()
        elif args.tpp == "neural":
            tpp_hidden_dims = list(map(int, args.tpp_hdims.split("-")))
            tpp_model = NeuralPointProcess(
                cond_dim=x_dim,
                hidden_dims=tpp_hidden_dims,
                cond=args.tpp_cond,
                style=args.tpp_style,
                actfn=args.tpp_actfn,
                otreg_strength=args.tpp_otreg_strength,
                tol=args.tol)
        else:
            raise ValueError(f"Invalid tpp model {args.tpp}")

        if args.model == "gmm":
            model = CombinedSpatiotemporalModel(GaussianMixtureSpatialModel(),
                                                tpp_model).to(device)
        elif args.model == "cnf":
            model = CombinedSpatiotemporalModel(
                IndependentCNF(dim=x_dim,
                               hidden_dims=list(map(int,
                                                    args.hdims.split("-"))),
                               layer_type=args.layer_type,
                               actfn=args.actfn,
                               tol=args.tol,
                               otreg_strength=args.otreg_strength,
                               squash_time=True), tpp_model).to(device)
        elif args.model == "tvcnf":
            model = CombinedSpatiotemporalModel(
                IndependentCNF(dim=x_dim,
                               hidden_dims=list(map(int,
                                                    args.hdims.split("-"))),
                               layer_type=args.layer_type,
                               actfn=args.actfn,
                               tol=args.tol,
                               otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        elif args.model == "jumpcnf":
            model = CombinedSpatiotemporalModel(
                JumpCNF(dim=x_dim,
                        hidden_dims=list(map(int, args.hdims.split("-"))),
                        layer_type=args.layer_type,
                        actfn=args.actfn,
                        tol=args.tol,
                        otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        elif args.model == "attncnf":
            model = CombinedSpatiotemporalModel(
                SelfAttentiveCNF(dim=x_dim,
                                 hidden_dims=list(
                                     map(int, args.hdims.split("-"))),
                                 layer_type=args.layer_type,
                                 actfn=args.actfn,
                                 l2_attn=args.l2_attn,
                                 tol=args.tol,
                                 otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        else:
            raise ValueError(f"Invalid model {args.model}")

    params = []
    attn_params = []
    for name, p in model.named_parameters():
        if "self_attns" in name:
            attn_params.append(p)
        else:
            params.append(p)

    optimizer = torch.optim.AdamW([{
        "params": params
    }, {
        "params": attn_params
    }],
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  betas=(0.9, 0.98))

    if rank == 0:
        ema = utils.ExponentialMovingAverage(model)

    model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    if rank == 0:
        logger.info(model)

    begin_itr = 0
    checkpt_path = os.path.join(savepath, "model.pth")
    if os.path.exists(checkpt_path):
        # Restart from checkpoint if run is a restart.
        if rank == 0:
            logger.info(f"Resuming checkpoint from {checkpt_path}")
        checkpt = torch.load(checkpt_path, "cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1

    elif args.resume:
        # Check the resume flag if run is new.
        if rank == 0:
            logger.info(f"Resuming model from {args.resume}")
        checkpt = torch.load(args.resume, "cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1

    space_loglik_meter = utils.RunningAverageMeter(0.98)
    time_loglik_meter = utils.RunningAverageMeter(0.98)
    gradnorm_meter = utils.RunningAverageMeter(0.98)

    model.train()
    start_time = time.time()
    iteration_counter = itertools.count(begin_itr)
    begin_epoch = begin_itr // len(train_epoch_iter)
    for epoch in range(begin_epoch,
                       math.ceil(args.num_iterations / len(train_epoch_iter))):
        batch_iter = train_epoch_iter.next_epoch_itr(shuffle=True)
        for batch in batch_iter:
            itr = next(iteration_counter)

            optimizer.zero_grad()

            event_times, spatial_locations, input_mask = map(
                lambda x: cast(x, device), batch)
            N, T = input_mask.shape
            num_events = input_mask.sum()

            if num_events == 0:
                raise RuntimeError("Got batch with no observations.")

            space_loglik, time_loglik = model(event_times, spatial_locations,
                                              input_mask, t0, t1)

            space_loglik = space_loglik.sum() / num_events
            time_loglik = time_loglik.sum() / num_events
            loglik = time_loglik + space_loglik

            space_loglik_meter.update(space_loglik.item())
            time_loglik_meter.update(time_loglik.item())

            loss = loglik.mul(-1.0).mean()
            loss.backward()

            # Set learning rate
            total_itrs = math.ceil(
                args.num_iterations /
                len(train_epoch_iter)) * len(train_epoch_iter)
            lr = learning_rate_schedule(itr, args.warmup_itrs, args.lr,
                                        total_itrs)
            set_learning_rate(optimizer, lr)

            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=args.gradclip).item()
            gradnorm_meter.update(grad_norm)

            optimizer.step()

            if rank == 0:
                if itr > 0.8 * args.num_iterations:
                    ema.apply()
                else:
                    ema.apply(decay=0.0)

            if rank == 0:
                tb_writer.add_scalar("train/lr", lr, itr)
                tb_writer.add_scalar("train/temporal_loss", time_loglik.item(),
                                     itr)
                tb_writer.add_scalar("train/spatial_loss", space_loglik.item(),
                                     itr)
                tb_writer.add_scalar("train/grad_norm", grad_norm, itr)

            if itr % args.logfreq == 0:
                elapsed_time = time.time() - start_time

                # Average NFE across devices.
                nfe = 0
                for m in model.modules():
                    if isinstance(m, TimeVariableCNF) or isinstance(
                            m, TimeVariableODE):
                        nfe += m.nfe
                nfe = torch.tensor(nfe).to(device)
                dist.all_reduce(nfe, op=dist.ReduceOp.SUM)
                nfe = nfe // world_size

                # Sum memory usage across devices.
                mem = torch.tensor(memory_usage_psutil()).float().to(device)
                dist.all_reduce(mem, op=dist.ReduceOp.SUM)

                if rank == 0:
                    logger.info(
                        f"Iter {itr} | Epoch {epoch} | LR {lr:.5f} | Time {elapsed_time:.1f}"
                        f" | Temporal {time_loglik_meter.val:.4f}({time_loglik_meter.avg:.4f})"
                        f" | Spatial {space_loglik_meter.val:.4f}({space_loglik_meter.avg:.4f})"
                        f" | GradNorm {gradnorm_meter.val:.2f}({gradnorm_meter.avg:.2f})"
                        f" | NFE {nfe.item()}"
                        f" | Mem {mem.item():.2f} MB")

                    tb_writer.add_scalar("train/nfe", nfe, itr)
                    tb_writer.add_scalar("train/time_per_itr",
                                         elapsed_time / args.logfreq, itr)

                start_time = time.time()

            if rank == 0 and itr % args.testfreq == 0:
                # ema.swap()
                val_space_loglik, val_time_loglik = validate(
                    model, val_loader, t0, t1, device)
                test_space_loglik, test_time_loglik = validate(
                    model, test_loader, t0, t1, device)
                # ema.swap()
                logger.info(
                    f"[Test] Iter {itr} | Val Temporal {val_time_loglik:.4f} | Val Spatial {val_space_loglik:.4f}"
                    f" | Test Temporal {test_time_loglik:.4f} | Test Spatial {test_space_loglik:.4f}"
                )

                tb_writer.add_scalar("val/temporal_loss", val_time_loglik, itr)
                tb_writer.add_scalar("val/spatial_loss", val_space_loglik, itr)

                tb_writer.add_scalar("test/temporal_loss", test_time_loglik,
                                     itr)
                tb_writer.add_scalar("test/spatial_loss", test_space_loglik,
                                     itr)

                torch.save(
                    {
                        "itr": itr,
                        "state_dict": model.module.state_dict(),
                        "optim_state_dict": optimizer.state_dict(),
                        "ema_parmas": ema.ema_params,
                    }, checkpt_path)

                start_time = time.time()

    if rank == 0:
        tb_writer.close()
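
Example #6 adjusts the learning rate through learning_rate_schedule and set_learning_rate, neither of which appears in the excerpt. One common choice, linear warmup followed by cosine decay, is sketched below; the exact shape the authors use is an assumption.

import math

def learning_rate_schedule(itr, warmup_itrs, base_lr, total_itrs):
    # Linear warmup, then cosine decay to zero over the remaining iterations.
    if itr < warmup_itrs:
        return base_lr * (itr + 1) / max(1, warmup_itrs)
    progress = (itr - warmup_itrs) / max(1, total_itrs - warmup_itrs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))

def set_learning_rate(optimizer, lr):
    for group in optimizer.param_groups:
        group["lr"] = lr
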
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data',
                        choices=[
                            'swissroll', '8gaussians', 'pinwheel', 'circles',
                            'moons', '2spirals', 'checkerboard', 'rings'
                        ],
                        type=str,
                        default='moons')
    parser.add_argument('--niters', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--critic_weight_decay', type=float, default=0)
    parser.add_argument('--save', type=str, default='/tmp/test_lsd')
    parser.add_argument('--mode',
                        type=str,
                        default="lsd",
                        choices=['lsd', 'sm'])
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--c_iters', type=int, default=5)
    parser.add_argument('--l2', type=float, default=10.)
    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--n_steps', type=int, default=10)
    args = parser.parse_args()

    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # fit a gaussian to the training data
    init_size = 1000
    init_batch = sample_data(args, init_size).requires_grad_()
    mu, std = init_batch.mean(0), init_batch.std(0)
    base_dist = distributions.Normal(mu, std)

    # neural netz
    critic = networks.SmallMLP(2, n_out=2)
    net = networks.SmallMLP(2)

    ebm = EBM(net, base_dist if args.base_dist else None)
    ebm.to(device)
    critic.to(device)

    # for sampling
    init_fn = lambda: base_dist.sample((args.test_batch_size,))
    cov = utils.cov(init_batch)
    sampler = HMCSampler(ebm,
                         .3,
                         5,
                         init_fn,
                         device=device,
                         covariance_matrix=cov)

    logger.info(ebm)
    logger.info(critic)

    # optimizers
    optimizer = optim.Adam(ebm.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(.0, .999))
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.lr,
                                  betas=(.0, .999),
                                  weight_decay=args.critic_weight_decay)

    time_meter = utils.RunningAverageMeter(0.98)
    loss_meter = utils.RunningAverageMeter(0.98)

    ebm.train()
    end = time.time()
    for itr in range(args.niters):

        optimizer.zero_grad()
        critic_optimizer.zero_grad()

        x = sample_data(args, args.batch_size)
        x.requires_grad_()

        if args.mode == "lsd":
            # our method

            # compute dlogp(x)/dx
            logp_u = ebm(x)
            sq = keep_grad(logp_u.sum(), x)
            fx = critic(x)
            # compute (dlogp(x)/dx)^T * f(x)
            sq_fx = (sq * fx).sum(-1)

            # compute/estimate Tr(df/dx)
            if args.exact_trace:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)

            stats = (sq_fx + tr_dfdx)
            loss = stats.mean()  # estimate of S(p, q)
            l2_penalty = (
                fx * fx).sum(1).mean() * args.l2  # penalty to enforce f \in F

            # adversarial!
            if args.c_iters > 0 and itr % (args.c_iters + 1) != 0:
                (-1. * loss + l2_penalty).backward()
                critic_optimizer.step()
            else:
                loss.backward()
                optimizer.step()

        elif args.mode == "sm":
            # score matching for reference
            fx = ebm(x)
            dfdx = torch.autograd.grad(fx.sum(),
                                       x,
                                       retain_graph=True,
                                       create_graph=True)[0]
            eps = torch.randn_like(dfdx)  # use hutchinson here as well
            epsH = torch.autograd.grad(dfdx,
                                       x,
                                       grad_outputs=eps,
                                       create_graph=True,
                                       retain_graph=True)[0]

            trH = (epsH * eps).sum(1)
            norm_s = (dfdx * dfdx).sum(1)

            loss = (trH + .5 * norm_s).mean()
            loss.backward()
            optimizer.step()
        else:
            assert False

        loss_meter.update(loss.item())
        time_meter.update(time.time() - end)

        if itr % args.log_freq == 0:
            log_message = (
                'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.4f}({:.4f})'.
                format(itr, time_meter.val, time_meter.avg, loss_meter.val,
                       loss_meter.avg))
            logger.info(log_message)

        if itr % args.save_freq == 0 or itr == args.niters:
            ebm.cpu()
            utils.makedirs(args.save)
            torch.save({
                'args': args,
                'state_dict': ebm.state_dict(),
            }, os.path.join(args.save, 'checkpt.pth'))
            ebm.to(device)

        if itr % args.viz_freq == 0:
            # plot dat
            plt.clf()
            npts = 100
            p_samples = toy_data.inf_train_gen(args.data, batch_size=npts**2)
            q_samples = sampler.sample(args.n_steps)

            ebm.cpu()

            x_enc = critic(x)
            xes = x_enc.detach().cpu().numpy()
            trans = xes.min()
            scale = xes.max() - xes.min()
            xes = (xes - trans) / scale * 8 - 4

            plt.figure(figsize=(4, 4))
            visualize_transform(
                [p_samples, q_samples.detach().cpu().numpy(), xes],
                ["data", "model", "embed"], [ebm], ["model"],
                npts=npts)

            fig_filename = os.path.join(args.save, 'figs',
                                        '{:04d}.png'.format(itr))
            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename)
            plt.close()

            ebm.to(device)
        end = time.time()

    logger.info('Training has finished, can I get a yeet?')
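
The LSD branch in Example #7 calls keep_grad, exact_jacobian_trace, and approx_jacobian_trace. The sketches below are consistent with how they are invoked above: the approximate version is the standard Hutchinson trace estimator, and the exact version sums the Jacobian diagonal one output dimension at a time (implementations assumed, not copied from the source).

import torch

def keep_grad(output, input):
    # Gradient that stays in the autograd graph so it can be differentiated again.
    return torch.autograd.grad(output, input, create_graph=True,
                               retain_graph=True)[0]

def approx_jacobian_trace(fx, x):
    # Hutchinson estimator: E_eps[eps^T (df/dx) eps] with Gaussian eps.
    eps = torch.randn_like(fx)
    jvp = torch.autograd.grad(fx, x, grad_outputs=eps, create_graph=True,
                              retain_graph=True)[0]
    return (jvp * eps).sum(-1)

def exact_jacobian_trace(fx, x):
    # Sum the Jacobian diagonal one output dimension at a time.
    trace = 0.0
    for i in range(fx.size(-1)):
        grad = torch.autograd.grad(fx[:, i].sum(), x, create_graph=True,
                                   retain_graph=True)[0]
        trace = trace + grad[:, i]
    return trace
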