Example No. 1
import torch
from time import perf_counter
from tqdm import tqdm
# Assumed to be in scope from the project (module paths not shown in this snippet):
# otp, Encoder, Decoder, gumbel_softmax, temp_scheduler, nearest_assignment


def execute(args):
    geometries, forces, features = otp.data(args)

    if args.cg_ones:
        cg_features = torch.ones(
            args.bs, args.ncg, 1, dtype=args.precision, device=args.device
        )
    else:
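        # Build one-hot identity features: feature row j of every sample is the basis vector e_j.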
        cg_features = torch.zeros(
            args.bs, args.ncg, args.ncg, dtype=args.precision, device=args.device
        )
        cg_features.scatter_(
            -1,
            torch.arange(args.ncg, device=args.device)
            .expand(args.bs, args.ncg)
            .unsqueeze(-1),
            1.0,
        )

    encoder = Encoder(args).to(device=args.device)
    decoder = Decoder(args).to(device=args.device)
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr
    )
    temp_sched = temp_scheduler(
        args.epochs,
        args.tdr,
        args.temp,
        args.tmin,
        dtype=args.precision,
        device=args.device,
    )
    n_batches, geometries, forces, features = otp.batch(
        geometries, forces, features, args.bs
    )

    dynamics = []
    summaries = []
    wall_start = perf_counter()
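    # Seeding here fixes the Gumbel noise and batch order below, but not the parameter init above.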
    torch.manual_seed(args.seed)
    for epoch in tqdm(range(args.epochs)):
        for step, batch in tqdm(
            enumerate(torch.randperm(n_batches, device=args.device))
        ):
            feat, geo, force = features[batch], geometries[batch], forces[batch]

            # Autoencoder: assign atoms to CG sites via Gumbel-softmax.
            logits = encoder(feat, geo)
            cg_assign, st_cg_assign = gumbel_softmax(
                logits, temp_sched[epoch], dtype=args.precision, device=args.device
            )
            E = cg_assign / cg_assign.sum(1).unsqueeze(1)
            cg_xyz = torch.einsum("zij,zik->zjk", E, geo)

            # The target is the projection of atoms, grouped by atomic number, onto their coarse-grained site.
            relative_xyz = (
                geo.unsqueeze(2).cpu().detach() - cg_xyz.unsqueeze(1).cpu().detach()
            )
            nearest_assign = nearest_assignment(cg_xyz, geo)
            if args.gumble_sm_proj:
                cg_proj = otp.project_onto_cg(relative_xyz, cg_assign, feat, args)
            elif args.nearest:
                cg_proj = otp.project_onto_cg(relative_xyz, nearest_assign, feat, args)
            else:
                cg_proj = otp.project_onto_cg(relative_xyz, st_cg_assign, feat, args)

            pred_sph = decoder(cg_features, cg_xyz.clone().detach())
            cg_proj = cg_proj.reshape_as(pred_sph)
            loss_ae = (cg_proj - pred_sph).pow(2).sum(-1).div(args.atomic_nums).mean()

            if args.fm and epoch >= args.fm_epoch:
                # Force matching
                cg_force_assign, _ = gumbel_softmax(
                    logits,
                    temp_sched[epoch] * args.force_temp_coeff,
                    device=args.device,
                    dtype=args.precision,
                )
                cg_force = torch.einsum("zij,zik->zjk", cg_force_assign, force)
                loss_fm = cg_force.pow(2).sum(-1).mean()

                loss = loss_ae + args.fm_co * loss_fm

            else:
                loss_fm = torch.tensor(0)
                loss = loss_ae

            dynamics.append(
                {
                    "loss_ae": loss_ae.item(),
                    "loss_fm": loss_fm.item(),
                    "loss": loss.item(),
                    "epoch": epoch,
                    "step": step,
                    "batch": batch.item(),
                }
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        wall = perf_counter() - wall_start
        if wall > args.wall:
            break

        summaries.append(
            {
                "loss_ae": loss_ae.item(),
                "loss_fm": loss_fm.item(),
                "loss": loss.item(),
                "epoch": epoch,
                "step": step,
                "batch": batch.item(),
                "cg_xyz": cg_xyz,
                "pred_sph": pred_sph,
                "sph": cg_proj,
                "temp": temp_sched[epoch].item(),
                "gumble": cg_assign,
                "st_gumble": st_cg_assign,
                "nearest": nearest_assign,
            }
        )

    return {
        "args": args,
        "dynamics": dynamics,
        "summaries": summaries,
        # 'train': {
        #     'pred': evaluate(f, features, geometry, train[:len(test)]),
        #     'true': forces[train[:len(test)]],
        # },
        # 'test': {
        #     'pred': evaluate(f, features, geometry, test[:len(train)]),
        #     'true': forces[test[:len(train)]],
        # },
        "encoder": encoder.state_dict() if args.save_state else None,
        "decoder": decoder.state_dict() if args.save_state else None,
    }
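
All three examples rely on a project-specific gumbel_softmax(logits, temp, dtype=..., device=...) that returns a (soft, straight-through) pair of assignment matrices. What follows is a minimal sketch of that pattern under the usual straight-through estimator, not the repo's actual implementation; the name gumbel_softmax_sketch and the 1e-20 clamping constant are assumptions.

import torch

def gumbel_softmax_sketch(logits, temp, dtype=torch.float32, device="cpu"):
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1); eps guards log(0).
    u = torch.rand(logits.shape, dtype=dtype, device=device)
    noise = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    # Soft assignment: temperature-scaled softmax over CG sites (last dim).
    soft = torch.softmax((logits.to(dtype=dtype, device=device) + noise) / temp, dim=-1)
    # Straight-through variant: hard one-hot on the forward pass, soft gradient on the backward pass.
    hard = torch.zeros_like(soft).scatter_(-1, soft.argmax(-1, keepdim=True), 1.0)
    st = hard - soft.detach() + soft
    return soft, st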
Example No. 2
import torch
from time import perf_counter
from tqdm import tqdm
# Assumed to be in scope from the project: otp, dense, equi, gumbel_softmax, temp_scheduler


def execute(args):
    dense_dict = torch.load(args.dense, map_location=args.device)
    geometries, forces, features = otp.data(args)

    encoder_dense = dense.Encoder(
        in_dim=geometries.size(1), out_dim=args.ncg, device=args.device
    ).to(args.device)
    encoder_dense.load_state_dict(dense_dict["encoder"])
    encoder_dense.weight1.detach_()
    decoder_dense = dense.Decoder(in_dim=args.ncg, out_dim=geometries.size(1)).to(
        args.device
    )
    decoder_dense.load_state_dict(dense_dict["decoder"])
    decoder_dense.weight.detach_()

    if args.cg_ones:
        Rs_in = [[(1, 0)]]
    elif args.project_one:
        Rs_in = [[(1, 0), (1, 2)]]
    elif args.soln:
        Rs_in = [[(1, l) for l in range(args.proj_lmax + 1)] * 2]
    elif args.cg_specific_atom:
        raise NotImplementedError()
        # Rs_in = [[(1, 0), (1, args.cg_specific_atom)]]
    else:
        Rs_in = [[(args.ncg, 0)]]

    # Encoder... TBD
    decoder = equi.Decoder(args, Rs_in).to(device=args.device)
    # optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr)
    optimizer = torch.optim.Adam(list(decoder.parameters()), lr=args.lr)
    temp_sched = temp_scheduler(
        args.epochs,
        args.tdr,
        args.temp,
        args.tmin,
        dtype=args.precision,
        device=args.device,
    )
    n_batches, geometries, forces, features = otp.batch(
        geometries, forces, features, args.bs
    )

    dynamics = []
    summaries = []
    wall_start = perf_counter()
    torch.manual_seed(args.seed)
    for epoch in tqdm(range(args.epochs)):
        # for step, batch in tqdm(enumerate(torch.randperm(n_batches, device=args.device))):
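        # NOTE: every step reuses batch 0; the shuffled iteration above is left commented out.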
        for step, batch in tqdm(
            enumerate([torch.tensor(0, device=args.device)] * n_batches)
        ):
            feat, geo, force = features[batch], geometries[batch], forces[batch]

            # Autoencoder: coarse-grain with the frozen dense encoder.
            cg_xyz = encoder_dense(geo, temp_sched[epoch])
            logits = encoder_dense.weight1.t()
            cg_assign, st_cg_assign = gumbel_softmax(
                logits, temp_sched[epoch], dtype=args.precision, device=args.device
            )
            decoded = decoder_dense(cg_xyz)
            loss_ae_dense = (decoded - geo).pow(2).sum(-1).mean()

            # Ground truth: the projection of atoms, grouped by atomic number, onto their coarse-grained site.
            relative_xyz = (
                geo.unsqueeze(2).cpu().detach() - cg_xyz.unsqueeze(1).cpu().detach()
            )
            nearest_assign = equi.nearest_assignment(cg_xyz, geo)
            if args.gumble_sm_proj:
                cg_proj = otp.project_onto_cg(relative_xyz, cg_assign, feat, args)
            elif args.nearest:
                cg_proj = otp.project_onto_cg(relative_xyz, nearest_assign, feat, args)
            else:
                cg_proj = otp.project_onto_cg(relative_xyz, st_cg_assign, feat, args)
            cg_proj = cg_proj.reshape(args.bs, args.ncg, -1)

            # Features and Predict
            if args.cg_ones:
                cg_features = torch.ones(
                    args.bs, args.ncg, 1, dtype=args.precision, device=args.device
                )
            elif args.project_one:
                # Select the atom nearest to a CG site, project it, and give that single atom the feature.
                relative_xyz = geo.unsqueeze(2) - cg_xyz.unsqueeze(1)
                nearest_atom_ind = relative_xyz.norm(dim=-1).argmin(1).squeeze()
                cg_atom_ind = 2
                l2_features = otp.project_atom_onto_cg_features(
                    relative_xyz,
                    2,
                    nearest_atom_ind[cg_atom_ind],
                    cg_atom_ind,
                    dtype=args.precision,
                    device=args.device,
                )
                l1_features = torch.ones(
                    *l2_features.shape[:2], 1, device=args.device, dtype=args.precision
                )
                cg_features = torch.cat([l1_features, l2_features], dim=-1)
            elif args.soln:
                cg_features = cg_proj.clone()  # Give solution
            elif args.cg_specific_atom:
                raise NotImplementedError()
                # # The purpose is to select the nearest atom to a cg_atom, project it, give a single atom that feature.
                # relative_xyz = geo.unsqueeze(2) - cg_xyz.unsqueeze(1)
                # nearest_atom_ind = relative_xyz.norm(dim=-1).argmin(1).squeeze()
                # cg_atom_ind = 2
                # l2_features = otp.project_atom_onto_cg_features(relative_xyz, args.cg_specific_atom, nearest_atom_ind[cg_atom_ind],
                #                                                 cg_atom_ind, dtype=args.precision, device=args.device)
                # l1_features = torch.ones(
                #     *l2_features.shape[:2], 1, device=args.device, dtype=args.precision
                # )
                # cg_features = torch.cat([l1_features, l2_features], dim=-1)
            else:
                cg_features = torch.zeros(
                    args.bs,
                    args.ncg,
                    args.ncg,
                    dtype=args.precision,
                    device=args.device,
                )
                cg_features.scatter_(
                    -1,
                    torch.arange(args.ncg, device=args.device)
                    .expand(args.bs, args.ncg)
                    .unsqueeze(-1),
                    1.0,
                )

            pred_sph = decoder(cg_features, cg_xyz.clone().detach())

            # Loss
            loss_ae_equi = (
                (cg_proj - pred_sph).pow(2).sum(-1).div(args.atomic_nums).mean()
            )
            if args.fm and epoch >= args.fm_epoch:
                # Force matching
                cg_force_assign, _ = gumbel_softmax(
                    logits,
                    temp_sched[epoch] * args.force_temp_coeff,
                    device=args.device,
                    dtype=args.precision,
                )
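                # logits is 2-D here (shared across the batch), so "..." matches no extra dims.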
                cg_force = torch.einsum("...ij,zik->zjk", cg_force_assign, force)
                loss_fm = cg_force.pow(2).sum(-1).mean()

                # loss = loss_ae_equi + loss_ae_dense + args.fm_co * loss_fm
            else:
                loss_fm = torch.tensor(0)
                # loss = loss_ae_equi + loss_ae_dense
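            # Only the equivariant AE loss is optimized; loss_ae_dense and loss_fm are logged in dynamics only.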
            loss = loss_ae_equi

            dynamics.append(
                {
                    "loss_ae_equi": loss_ae_equi.item(),
                    "loss_ae_dense": loss_ae_dense.item(),
                    "loss_fm": loss_fm.item(),
                    "loss": loss.item(),
                    "epoch": epoch,
                    "step": step,
                    "batch": batch.item(),
                }
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        wall = perf_counter() - wall_start
        if wall > args.wall:
            break

        summaries.append(
            {
                "loss_ae_equi": loss_ae_equi.item(),
                "loss_ae_dense": loss_ae_dense.item(),
                "loss_fm": loss_fm.item(),
                "loss": loss.item(),
                "epoch": epoch,
                "step": step,
                "batch": batch.item(),
                "cg_xyz": cg_xyz,
                "pred_sph": pred_sph,
                "sph": cg_proj,
                "temp": temp_sched[epoch].item(),
                "gumble": cg_assign,
                "st_gumble": st_cg_assign,
                "nearest": nearest_assign,
            }
        )

    return {
        "args": args,
        "dynamics": dynamics,
        "summaries": summaries,
        # 'train': {
        #     'pred': evaluate(f, features, geometry, train[:len(test)]),
        #     'true': forces[train[:len(test)]],
        # },
        # 'test': {
        #     'pred': evaluate(f, features, geometry, test[:len(train)]),
        #     'true': forces[test[:len(train)]],
        # },
        # 'encoder': encoder.state_dict() if args.save_state else None,
        "decoder": decoder.state_dict() if args.save_state else None,
    }
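
Examples No. 1 and No. 2 also call a nearest_assignment helper (equi.nearest_assignment above) for a hard distance-based atom-to-site assignment, and every example draws annealing temperatures from temp_scheduler. Both are project-specific; the sketches below assume geo of shape (bs, n_atoms, 3), cg_xyz of shape (bs, ncg, 3), and read tdr as a per-epoch decay rate, so treat them as illustrations rather than the real helpers.

import torch

def nearest_assignment_sketch(cg_xyz, geo):
    # Pairwise atom-to-site distances: (bs, n_atoms, ncg).
    dist = torch.cdist(geo, cg_xyz)
    # One-hot assignment of each atom to its nearest CG site.
    idx = dist.argmin(dim=-1, keepdim=True)
    return torch.zeros_like(dist).scatter_(-1, idx, 1.0)

def temp_scheduler_sketch(epochs, tdr, temp, tmin, dtype=torch.float32, device="cpu"):
    # Exponential decay temp * tdr**epoch with floor tmin; indexable as temp_sched[epoch].
    e = torch.arange(epochs, dtype=dtype, device=device)
    return torch.clamp(temp * tdr**e, min=tmin)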
Example No. 3
import torch
from time import perf_counter
from tqdm import tqdm
# Assumed to be in scope from the project: otp, Encoder, Decoder, gumbel_softmax, temp_scheduler


def execute(args):
    geometries, forces, features = otp.data(args)

    encoder = Encoder(args).to(device=args.device)
    decoder = Decoder(in_dim=args.ncg,
                      out_dim=geometries.size(1)).to(args.device)
    optimizer = torch.optim.Adam(list(encoder.parameters()) +
                                 list(decoder.parameters()),
                                 lr=args.lr)
    temp_sched = temp_scheduler(
        args.epochs,
        args.tdr,
        args.temp,
        args.tmin,
        dtype=args.precision,
        device=args.device,
    )
    n_batches, geometries, forces, features = otp.batch(
        geometries, forces, features, args.bs)

    dynamics = []
    summaries = []
    wall_start = perf_counter()
    torch.manual_seed(args.seed)
    for epoch in tqdm(range(args.epochs)):
        for step, batch in tqdm(
                enumerate(torch.randperm(n_batches, device=args.device))):
            feat, geo, force = features[batch], geometries[batch], forces[batch]

            # Autoencoder: assign atoms to CG sites and reconstruct the coordinates.
            logits = encoder(feat, geo)
            cg_assign, st_cg_assign = gumbel_softmax(logits,
                                                     temp_sched[epoch],
                                                     dtype=args.precision,
                                                     device=args.device)
            E = cg_assign / cg_assign.sum(1).unsqueeze(1)
            cg_xyz = torch.einsum("zij,zik->zjk", E, geo)
            decoded = decoder(cg_xyz)
            loss_ae = (decoded - geo).pow(2).sum(-1).mean()

            if args.fm and epoch >= args.fm_epoch:
                # Force matching
                cg_force_assign, _ = gumbel_softmax(
                    logits,
                    temp_sched[epoch] * args.force_temp_coeff,
                    device=args.device,
                    dtype=args.precision,
                )
                cg_force = torch.einsum("zij,zik->zjk", cg_force_assign, force)
                loss_fm = cg_force.pow(2).sum(-1).mean()
                loss = loss_ae + args.fm_co * loss_fm
            else:
                loss_fm = torch.tensor(0)
                loss = loss_ae

            dynamics.append({
                "loss_ae": loss_ae.item(),
                "loss_fm": loss_fm.item(),
                "loss": loss.item(),
                "epoch": epoch,
                "step": step,
                "batch": batch.item(),
            })

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        wall = perf_counter() - wall_start
        if wall > args.wall:
            break

        summaries.append({
            "loss_ae": loss_ae.item(),
            "loss_fm": loss_fm.item(),
            "loss": loss.item(),
            "epoch": epoch,
            "step": step,
            "batch": batch.item(),
            "cg_xyz": cg_xyz,
            "temp": temp_sched[epoch].item(),
            "gumble": cg_assign,
            "st_gumble": st_cg_assign,
            "reconstructed": decoded,
        })

    return {
        "args": args,
        "dynamics": dynamics,
        "summaries": summaries,
        # 'train': {
        #     'pred': evaluate(f, features, geometry, train[:len(test)]),
        #     'true': forces[train[:len(test)]],
        # },
        # 'test': {
        #     'pred': evaluate(f, features, geometry, test[:len(train)]),
        #     'true': forces[test[:len(train)]],
        # },
        "encoder": encoder.state_dict() if args.save_state else None,
        "decoder": decoder.state_dict() if args.save_state else None,
    }
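
For reference, the contraction torch.einsum("zij,zik->zjk", E, geo) that produces cg_xyz in every example is a per-sample weighted average: E[z, i, j] is atom i's weight on CG site j, so summing over i is a batched E^T @ geo. A small self-contained check (all shapes and names here are illustrative):

import torch

bs, n_atoms, ncg = 2, 5, 3
assign = torch.softmax(torch.randn(bs, n_atoms, ncg), dim=-1)  # soft assignments
E = assign / assign.sum(1, keepdim=True)  # normalize over atoms, as in the examples
geo = torch.randn(bs, n_atoms, 3)  # atomic coordinates

cg_xyz = torch.einsum("zij,zik->zjk", E, geo)  # (bs, ncg, 3)
# The same contraction written as a batched matrix product.
assert torch.allclose(cg_xyz, E.transpose(1, 2) @ geo, atol=1e-6)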