from pathlib import Path

import os

from pursuit.agents.handcoded.eps_random import EpsRandomAgent
from pursuit.agents.handcoded.greedy import GreedyAgent
from pursuit.agents.handcoded.random import RandomAgent
from pursuit.agents.handcoded.teammate_aware import TeammateAwareAgent
from utils import save_run

# Root directory that receives every generated dataset; created on demand.
path = Path('datasets')
path.mkdir(parents=True, exist_ok=True)

# Disabled configurations, kept for reference:
# save_run(path / '5x5_greedy', 10000, [GreedyAgent(i) for i in range(4)], world_size=(5, 5))
# save_run(path / '5x5_greedy_random', 10000, [GreedyAgent(i) for i in range(3)] + [EpsRandomAgent(3, RandomAgent(3), 0.5)], world_size=(5, 5))
# save_run(path / '5x5_ta', 10000, [TeammateAwareAgent(i) for i in range(4)], world_size=(5, 5))
# save_run(path / '10x10_greedy', 10000, [GreedyAgent(i) for i in range(4)], world_size=(10, 10))
# save_run(path / '10x10_ta', 10000, [TeammateAwareAgent(i) for i in range(4)], world_size=(10, 10))
# save_run(path / '20x20_greedy', 10000, [GreedyAgent(i) for i in range(4)], world_size=(20, 20))
# save_run(path / '20x20_ta', 10000, [TeammateAwareAgent(i) for i in range(4)], world_size=(20, 20))

# Active runs: (dataset name, world size, epsilon handed to EpsRandomAgent).
# Every run uses three greedy agents (ids 0-2) plus one EpsRandomAgent
# wrapping a RandomAgent (id 3).
# NOTE(review): epsilon is presumably the probability of a random move and
# 1000 presumably the number of recorded runs - confirm against
# EpsRandomAgent and utils.save_run.
for run_name, grid, epsilon in (
        ('10x10_greedy_random', (10, 10), 0.25),
        ('20x20_greedy_random', (20, 20), 0.1),
):
    team = [GreedyAgent(agent_id) for agent_id in range(3)]
    team.append(EpsRandomAgent(3, RandomAgent(3), epsilon))
    save_run(path / run_name, 1000, team, world_size=grid)
Example #2
0
    def train(self, train_loader, val_loader, h: int, w: int, parser):
        """
        Train coarse and fine model on training data and run validation.

        Every epoch optimises the models on ``train_loader``, optionally logs
        an intermediate validation loss every ``args.log_iterations``
        iterations (``args.early_validation``), then runs a full validation
        pass whose losses, re-rendered images and densities are written to
        ``self.writer``. Both models are stored with ``save_run`` at the end
        of every epoch.

        Each batch must be a sequence of tensors whose last element is the
        ground-truth RGB value.

        Parameters
        ----------
        train_loader : training data loader object.
        val_loader : validation data loader object.
        h : int
            height of images.
        w : int
            width of images.
        parser : object forwarded unchanged to ``save_run`` (presumably the
            argument parser of the run - confirm against ``save_run``).
        """
        args = self.args
        iter_per_epoch = len(train_loader)
        pixels_per_image = h * w
        # Divisor for averaged validation losses; 1 for an empty loader so
        # that no ZeroDivisionError is raised.
        num_val_batches = max(len(val_loader), 1)

        def to_device(batch):
            # Move every tensor of a batch to the device used for training.
            return [element.to(self.device) for element in batch]

        print('START TRAIN.')

        for epoch in range(
                args.num_epochs):  # loop over the dataset multiple times
            self.model_coarse.train()
            self.model_fine.train()
            train_loss = 0
            for i, data in enumerate(train_loader):
                data = to_device(data)
                rgb_truth = data[-1]

                rgb, rgb_fine, _, _ = self.pipeline(data)

                self.optim.zero_grad()

                loss = self.nerf_loss(rgb, rgb_fine, rgb_truth)
                loss.backward()
                self.optim.step()

                loss_item = loss.item()
                if i % args.log_iterations == args.log_iterations - 1:
                    print('[Epoch %d, Iteration %5d/%5d] TRAIN loss: %.7f' %
                          (epoch + 1, i + 1, iter_per_epoch, loss_item))
                    if args.early_validation:
                        self.model_coarse.eval()
                        self.model_fine.eval()
                        val_loss = 0
                        # No gradients are needed for validation.
                        with torch.no_grad():
                            for val_data in val_loader:
                                val_data = to_device(val_data)
                                val_rgb, val_rgb_fine, _, _ = self.pipeline(
                                    val_data)
                                val_loss += self.nerf_loss(
                                    val_rgb, val_rgb_fine,
                                    val_data[-1]).item()
                        self.writer.add_scalars(
                            'Loss curve every nth iteration', {
                                'train loss': loss_item,
                                'val loss': val_loss / num_val_batches
                            }, i // args.log_iterations + epoch *
                            (iter_per_epoch // args.log_iterations))
                        # Switch back to training mode; otherwise the rest
                        # of the epoch would run with the models in eval
                        # mode.
                        self.model_coarse.train()
                        self.model_fine.train()
                train_loss += loss_item
            print('[Epoch %d] Average loss of Epoch: %.7f' %
                  (epoch + 1, train_loss / iter_per_epoch))

            self.model_coarse.eval()
            self.model_fine.eval()
            val_loss = 0
            rerender_images = []
            samples = []
            ground_truth_images = []
            densities_list = []
            # Number of rays currently buffered in densities_list / samples.
            buffered_rays = 0
            image_counter = 0
            with torch.no_grad():
                for data in val_loader:
                    data = to_device(data)
                    rgb_truth = data[-1]
                    rgb, rgb_fine, ray_samples, densities = self.pipeline(data)

                    loss = self.nerf_loss(rgb, rgb_fine, rgb_truth)
                    val_loss += loss.item()

                    ground_truth_images.append(
                        rgb_truth.detach().cpu().numpy())
                    rerender_images.append(rgb_fine.detach().cpu().numpy())
                    samples.append(ray_samples.detach().cpu().numpy())
                    batch_densities = densities.detach().cpu().numpy()
                    densities_list.append(batch_densities)
                    buffered_rays += batch_densities.shape[0]
                    if buffered_rays < pixels_per_image:
                        continue

                    # At least one complete image is buffered: emit every
                    # complete image and keep the remainder for later
                    # batches.
                    pending_densities = np.concatenate(densities_list)
                    pending_samples = np.concatenate(samples)
                    while pending_densities.shape[0] >= pixels_per_image:
                        image_densities = pending_densities[
                            :pixels_per_image].reshape(-1)
                        image_samples = pending_samples[
                            :pixels_per_image].reshape(-1, 3)
                        pending_densities = pending_densities[
                            pixels_per_image:]
                        pending_samples = pending_samples[pixels_per_image:]
                        vedo_data(self.writer,
                                  image_densities,
                                  image_samples,
                                  image_warps=None,
                                  epoch=epoch + 1,
                                  image_idx=image_counter)
                        image_counter += 1
                    densities_list = [pending_densities]
                    samples = [pending_samples]
                    buffered_rays = pending_densities.shape[0]
            if len(val_loader) != 0:
                rerender_images = np.concatenate(rerender_images, 0).reshape(
                    (-1, h, w, 3))
                ground_truth_images = np.concatenate(
                    ground_truth_images).reshape((-1, h, w, 3))

            tensorboard_rerenders(self.writer,
                                  args.number_validation_images,
                                  rerender_images,
                                  ground_truth_images,
                                  step=epoch,
                                  ray_warps=None)

            mean_val_loss = val_loss / num_val_batches
            print('[Epoch %d] VAL loss: %.7f' % (epoch + 1, mean_val_loss))
            self.writer.add_scalars(
                'Loss Curve', {
                    'train loss': train_loss / iter_per_epoch,
                    'val loss': mean_val_loss
                }, epoch)

            # Checkpoint both models after every epoch.
            save_run(self.writer.log_dir, [self.model_coarse, self.model_fine],
                     ['model_coarse.pt', 'model_fine.pt'], parser)
        print('FINISH.')
Example #3
0
def train():
    """
    Build datasets, models and the solver for ``args.model_type`` and train.

    The command line is parsed with ``config_parser()``. Depending on
    ``args.model_type`` the matching dataset classes, networks and solver are
    created, the solver is trained on the ``train``/``val`` splits found
    below ``args.dataset_dir`` and the resulting models are stored with
    ``save_run`` in the solver's TensorBoard log directory. The
    ``append_smpl_params`` and ``append_to_nerf`` types additionally render
    an inference GIF.

    Raises
    ------
    ValueError
        If ``args.model_type`` is not one of the supported model types.
    """
    parser = config_parser()
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.default_device = device
    supported_model_types = (
        "nerf", "smpl_nerf", "append_to_nerf", "smpl", "warp",
        'vertex_sphere', "smpl_estimator", "original_nerf",
        'dummy_dynamic', 'image_wise_dynamic',
        "append_vertex_locations_to_nerf", 'append_smpl_params'
    )
    if args.model_type not in supported_model_types:
        raise ValueError("The model type {} does not exist.".format(
            args.model_type))

    # SMPL body model file used by the solvers that need an SMPL model.
    smpl_file_name = "SMPLs/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl"

    transform = transforms.Compose([
        NormalizeRGB(),
        CoarseSampling(args.near, args.far, args.number_coarse_samples),
        ToTensor()
    ])

    # ------------------------------------------------------------------
    # Datasets
    # ------------------------------------------------------------------
    train_dir = os.path.join(args.dataset_dir, 'train')
    val_dir = os.path.join(args.dataset_dir, 'val')
    train_json = os.path.join(train_dir, 'transforms.json')
    val_json = os.path.join(val_dir, 'transforms.json')
    if args.model_type == "nerf":
        train_data = RaysFromImagesDataset(train_dir, train_json, transform)
        val_data = RaysFromImagesDataset(val_dir, val_json, transform)
    elif args.model_type in ("smpl", "warp"):
        train_data = SmplDataset(train_dir,
                                 train_json,
                                 args,
                                 transform=NormalizeRGB())
        val_data = SmplDataset(val_dir,
                               val_json,
                               args,
                               transform=NormalizeRGB())
    elif args.model_type in ("smpl_nerf", "append_to_nerf",
                             "append_smpl_params"):
        train_data = SmplNerfDataset(train_dir, train_json, transform)
        val_data = SmplNerfDataset(val_dir, val_json, transform)
    elif args.model_type == "vertex_sphere":
        train_data = VertexSphereDataset(train_dir, train_json, args)
        val_data = VertexSphereDataset(val_dir, val_json, args)
    elif args.model_type == "smpl_estimator":
        transform = NormalizeRGBImage()
        train_data = SmplEstimatorDataset(train_dir, train_json,
                                          args.vertex_sphere_radius, transform)
        val_data = SmplEstimatorDataset(val_dir, val_json,
                                        args.vertex_sphere_radius, transform)
    elif args.model_type == "original_nerf":
        # This layout keeps both splits in dataset_dir and distinguishes
        # them through the transforms file only.
        train_data = OriginalNerfDataset(
            args.dataset_dir,
            os.path.join(args.dataset_dir, 'transforms_train.json'), transform)
        val_data = OriginalNerfDataset(
            args.dataset_dir,
            os.path.join(args.dataset_dir, 'transforms_val.json'), transform)
    elif args.model_type in ("dummy_dynamic",
                             "append_vertex_locations_to_nerf"):
        train_data = DummyDynamicDataset(train_dir, train_json, transform)
        val_data = DummyDynamicDataset(val_dir, val_json, transform)
    elif args.model_type == 'image_wise_dynamic':
        canonical_pose1 = torch.zeros(38).view(1, -1)
        canonical_pose2 = torch.zeros(2).view(1, -1)
        canonical_pose3 = torch.zeros(27).view(1, -1)
        arm_angle_l = torch.tensor([np.deg2rad(10)]).float().view(1, -1)
        arm_angle_r = torch.tensor([np.deg2rad(10)]).float().view(1, -1)
        smpl_estimator = DummyImageWiseEstimator(canonical_pose1,
                                                 canonical_pose2,
                                                 canonical_pose3, arm_angle_l,
                                                 arm_angle_r,
                                                 torch.zeros(10).view(1, -1),
                                                 torch.zeros(69).view(1, -1))
        train_data = ImageWiseDataset(train_dir, train_json, smpl_estimator,
                                      transform, args)
        val_data = ImageWiseDataset(val_dir, val_json, smpl_estimator,
                                    transform, args)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batchsize_val,
                                             shuffle=False,
                                             num_workers=0)

    # ------------------------------------------------------------------
    # Default encoders and networks (some branches below replace them)
    # ------------------------------------------------------------------
    position_encoder = PositionalEncoder(args.number_frequencies_postitional,
                                         args.use_identity_positional)
    direction_encoder = PositionalEncoder(args.number_frequencies_directional,
                                          args.use_identity_directional)
    model_coarse = RenderRayNet(args.netdepth,
                                args.netwidth,
                                position_encoder.output_dim * 3,
                                direction_encoder.output_dim * 3,
                                skips=args.skips)
    model_fine = RenderRayNet(args.netdepth_fine,
                              args.netwidth_fine,
                              position_encoder.output_dim * 3,
                              direction_encoder.output_dim * 3,
                              skips=args.skips_fine)

    # ------------------------------------------------------------------
    # Solver selection, training and saving
    # ------------------------------------------------------------------
    if args.model_type == "smpl_nerf":
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        positions_dim = position_encoder.output_dim if args.human_pose_encoding else 1
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1
        model_warp_field = WarpFieldNet(args.netdepth_warp, args.netwidth_warp,
                                        positions_dim * 3, human_pose_dim * 2)

        solver = SmplNerfSolver(model_coarse, model_fine, model_warp_field,
                                position_encoder, direction_encoder,
                                human_pose_encoder, train_data.canonical_smpl,
                                args, torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w)

        save_run(solver.writer.log_dir,
                 [model_coarse, model_fine, model_warp_field],
                 ['model_coarse.pt', 'model_fine.pt', 'model_warp_field.pt'],
                 parser)

    elif args.model_type == 'smpl':
        solver = SmplSolver(model_coarse, model_fine, position_encoder,
                            direction_encoder, args, torch.optim.Adam,
                            torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)
        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'nerf' or args.model_type == "original_nerf":
        solver = NerfSolver(model_coarse, model_fine, position_encoder,
                            direction_encoder, args, torch.optim.Adam,
                            torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)
        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'warp':
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        positions_dim = position_encoder.output_dim if args.human_pose_encoding else 1
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1
        model_warp_field = WarpFieldNet(args.netdepth_warp, args.netwidth_warp,
                                        positions_dim * 3, human_pose_dim * 2)
        solver = WarpSolver(model_warp_field, position_encoder,
                            direction_encoder, human_pose_encoder, args)
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir, [model_warp_field],
                 ['model_warp_field.pt'], parser)

    elif args.model_type == 'append_smpl_params':
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1

        # Pick the architecture first so that a checkpoint is loaded into
        # the networks that are actually trained (previously the loaded
        # weights were discarded whenever args.siren was set).
        net_class = SirenRenderRayNet if args.siren else RenderRayNet
        model_coarse = net_class(
            args.netdepth,
            args.netwidth,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 69,
            skips=args.skips,
            use_directional_input=args.use_directional_input)
        model_fine = net_class(
            args.netdepth_fine,
            args.netwidth_fine,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 69,
            skips=args.skips_fine,
            use_directional_input=args.use_directional_input)

        if args.load_run is not None:
            model_coarse.load_state_dict(
                torch.load(os.path.join(args.load_run, 'model_coarse.pt'),
                           map_location=device))
            model_fine.load_state_dict(
                torch.load(os.path.join(args.load_run, 'model_fine.pt'),
                           map_location=device))
            print("Models loaded from ", args.load_run)

        solver = AppendSmplParamsSolver(model_coarse, model_fine,
                                        position_encoder, direction_encoder,
                                        human_pose_encoder, args,
                                        torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)

        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

        model_dependent = [human_pose_encoder, human_pose_dim]
        inference_gif(solver.writer.log_dir, args.model_type, args, train_data,
                      val_data, position_encoder, direction_encoder,
                      model_coarse, model_fine, model_dependent)

    elif args.model_type == 'append_to_nerf':
        human_pose_encoder = PositionalEncoder(args.number_frequencies_pose,
                                               args.use_identity_pose)
        human_pose_dim = human_pose_encoder.output_dim if args.human_pose_encoding else 1
        model_coarse = RenderRayNet(
            args.netdepth,
            args.netwidth,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 2,
            skips=args.skips,
            use_directional_input=args.use_directional_input)
        model_fine = RenderRayNet(
            args.netdepth_fine,
            args.netwidth_fine,
            position_encoder.output_dim * 3,
            direction_encoder.output_dim * 3,
            human_pose_dim * 2,
            skips=args.skips_fine,
            use_directional_input=args.use_directional_input)
        solver = AppendToNerfSolver(model_coarse, model_fine, position_encoder,
                                    direction_encoder, human_pose_encoder,
                                    args, torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w,
                     parser)

        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

        model_dependent = [human_pose_encoder, human_pose_dim]
        inference_gif(solver.writer.log_dir, args.model_type, args, train_data,
                      val_data, position_encoder, direction_encoder,
                      model_coarse, model_fine, model_dependent)

    elif args.model_type == 'append_vertex_locations_to_nerf':
        model_coarse = AppendVerticesNet(args.netdepth,
                                         args.netwidth,
                                         position_encoder.output_dim * 3,
                                         direction_encoder.output_dim * 3,
                                         6890,
                                         additional_input_layers=1,
                                         skips=args.skips)
        model_fine = AppendVerticesNet(args.netdepth_fine,
                                       args.netwidth_fine,
                                       position_encoder.output_dim * 3,
                                       direction_encoder.output_dim * 3,
                                       6890,
                                       additional_input_layers=1,
                                       skips=args.skips_fine)
        smpl_estimator = DummySmplEstimatorModel(train_data.goal_poses,
                                                 train_data.betas)
        smpl_model = smplx.create(smpl_file_name, model_type='smpl')
        smpl_model.batchsize = args.batchsize
        solver = AppendVerticesSolver(model_coarse, model_fine, smpl_estimator,
                                      smpl_model, position_encoder,
                                      direction_encoder, args,
                                      torch.optim.Adam, torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w)

        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'vertex_sphere':
        solver = VertexSphereSolver(model_coarse, model_fine, position_encoder,
                                    direction_encoder, args, torch.optim.Adam,
                                    torch.nn.MSELoss())
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir, [model_coarse, model_fine],
                 ['model_coarse.pt', 'model_fine.pt'], parser)

    elif args.model_type == 'smpl_estimator':
        model = SmplEstimator(human_size=len(args.human_joints))

        solver = SmplEstimatorSolver(model, args, torch.optim.Adam,
                                     torch.nn.MSELoss())
        solver.train(train_loader, val_loader)
        save_run(solver.writer.log_dir, [model], ['model_smpl_estimator.pt'],
                 parser)

    elif args.model_type == "dummy_dynamic":
        smpl_model = smplx.create(smpl_file_name, model_type='smpl')
        smpl_model.batchsize = args.batchsize
        smpl_estimator = DummySmplEstimatorModel(train_data.goal_poses,
                                                 train_data.betas)
        # NOTE(review): the fine model is passed before the coarse model
        # here, unlike every other solver in this function - confirm
        # against the DynamicSolver signature.
        solver = DynamicSolver(model_fine, model_coarse, smpl_estimator,
                               smpl_model, position_encoder, direction_encoder,
                               args)
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir,
                 [model_coarse, model_fine, smpl_estimator],
                 ['model_coarse.pt', 'model_fine.pt', 'smpl_estimator.pt'],
                 parser)

    elif args.model_type == "image_wise_dynamic":
        if args.load_coarse_model is not None:
            print("Load model..")
            model_coarse.load_state_dict(
                torch.load(args.load_coarse_model, map_location=device))
            # The loaded coarse model is frozen and kept in eval mode.
            for params in model_coarse.parameters():
                params.requires_grad = False
            model_coarse.eval()
        # This model type uses loaders with a batch size of one instead of
        # the loaders created above.
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=1,
                                                   shuffle=True,
                                                   num_workers=0)
        val_loader = torch.utils.data.DataLoader(val_data,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=0)
        smpl_model = smplx.create(smpl_file_name, model_type='smpl')
        smpl_model.batchsize = args.batchsize
        solver = ImageWiseSolver(model_coarse, model_fine, smpl_estimator,
                                 smpl_model, position_encoder,
                                 direction_encoder, args)
        solver.train(train_loader, val_loader, train_data.h, train_data.w)
        save_run(solver.writer.log_dir,
                 [model_coarse, model_fine, smpl_estimator],
                 ['model_coarse.pt', 'model_fine.pt', 'smpl_estimator.pt'],
                 parser)
Example #4
0
def main():
    """
    Train and evaluate one reconstruction model, then archive the run.

    The model is selected through the hard-coded ``model_name``: either a key
    of the model table below or, when ``check_valid_filename`` accepts it,
    the path of a saved model that is loaded with ``torch.load``. The model
    is evaluated every epoch and also trained when an optimizer can be
    created for it. Afterwards the model, the relevant source files and
    the run settings are handed to ``save_run`` and the settings are printed.
    """
    # Alternatives: "mne-interp-mc", "vq-2-mc", "cvq-2-mc", or a path such as
    # HOME_PATH + "reconstruction/saved_models/" + "vq-2-mc".
    model_name = "nn"

    # A valid file name means model_name is the path of a saved model.
    saved_model = bool(check_valid_filename(model_name))

    z_dim = 30
    lr = 1e-3
    # sched = 'cycle'
    sched = None

    num_epochs = 200
    batch_size = 64
    num_examples_train = -1
    num_examples_eval = -1

    device = 'cuda'
    # NOTE(review): `normalize` is not defined in this function; it is
    # presumably a module-level callable (its __name__ is logged below) -
    # confirm.
    log_interval = 1
    tb_save_loc = "runs/testing/"
    select_channels = [0]  #[0,1,2,3]
    num_channels = len(select_channels)

    lengths = {
        # Single Channel Outputs
        "nn": 784,
        "cnn": 784,
        "vq": 784,
        "vq-2": 1024,
        "unet": 1024,

        # Multichannel Outputs
        "cnn-mc": 784,
        "vq-2-mc": 1024,
        "cvq-2-mc": 1023,

        # Baselines
        "avg-interp": 784,
        "mne-interp": 784,
        "mne-interp-mc": 1023,
    }

    # Factories instead of instances: only the selected model is built.
    model_factories = {
        # "nn" : cVAE1c(z_dim=z_dim),
        "nn":
        lambda: VAE1c(z_dim=z_dim),
        "vq-2":
        lambda: cVQVAE_2(in_channel=1),
        "unet":
        lambda: UNet(in_channels=num_channels),
        "vq-2-mc":
        lambda: VQVAE_2(in_channel=num_channels),
        "cvq-2-mc":
        lambda: cVQVAE_2(in_channel=num_channels),
        "cnn-mc":
        lambda: ConvVAE(num_channels=num_channels,
                        num_channels_out=num_channels,
                        z_dim=z_dim),
        "avg-interp":
        lambda: AvgInterpolation(),
        "mne-interp":
        lambda: MNEInterpolation(),
        "mne-interp-mc":
        lambda: MNEInterpolationMC(),
    }

    model_filenames = {
        "nn": HOME_PATH + "models/VAE1c.py",
        "cnn": HOME_PATH + "models/conv_VAE.py",
        "vq": HOME_PATH + "models/VQ_VAE_1c.py",
        "vq-2": HOME_PATH + "models/vq-vae-2-pytorch/vqvae.py",
        "unet": HOME_PATH + "models/unet.py",
        "cnn-mc": HOME_PATH + "models/conv_VAE.py",
        "vq-2-mc": HOME_PATH + "models/vq-vae-2-pytorch/vqvae.py",
        "mne-interp-mc": HOME_PATH + "denoise/fill_baseline_models.py",
    }

    if saved_model:
        model = torch.load(model_name)
        length = 1024  # TODO find way to set auto
    else:
        model = model_factories[model_name]()
        length = lengths[model_name]

    if model_name == "mne-interp" or model_name == "mne-interp-mc":
        select_channels = [0, 1, 2, 3]
    # else:
    # select_channels = [0,1,2]
    # select_channels = [0,1]#,2,3]

    train_files = TRAIN_NORMAL_FILES_CSV  #TRAIN_FILES_CSV
    # train_files =  TRAIN_FILES_CSV

    eval_files = DEV_NORMAL_FILES_CSV  #DEV_FILES_CSV
    # eval_files =  DEV_FILES_CSV
    eval_dataset = EEGDatasetMc(eval_files,
                                max_num_examples=num_examples_eval,
                                length=length,
                                normalize=normalize,
                                select_channels=select_channels)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

    target_filename = model_name
    run_filename = find_valid_filename(target_filename,
                                       HOME_PATH + 'denoise/' + tb_save_loc)
    tb_filename = tb_save_loc + run_filename
    writer = SummaryWriter(tb_filename)

    model = model.to(device)

    # Creating the optimizer raises ValueError for a model that cannot be
    # optimized; such a model is evaluated only.
    try:
        optimizer = optim.Adam(model.parameters(), lr=lr)
        train_model = True
    except ValueError:
        print("This Model Cannot Be Optimized")
        optimizer = None
        train_model = False
        sched = None

    # The previous "if saved_model: train_model = True" override was dropped:
    # it was a no-op when the optimizer existed and otherwise forced training
    # without an optimizer or a training loader.
    if train_model:
        train_dataset = EEGDatasetMc(train_files,
                                     max_num_examples=num_examples_train,
                                     length=length,
                                     normalize=normalize,
                                     select_channels=select_channels)
        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)
        print("m Train Dataset", len(train_dataset))

    print("m Eval Dataset", len(eval_dataset))

    scheduler = None

    if sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   lr,
                                   n_iter=len(train_loader) * num_epochs,
                                   momentum=None)
    for i in range(1, num_epochs + 1):
        if train_model:
            train(i,
                  train_loader,
                  model,
                  optimizer,
                  scheduler,
                  device,
                  writer,
                  log_interval=log_interval)
        eval(i, eval_loader, model, device, writer, log_interval=log_interval)

    save_dir = HOME_PATH + "denoise/saved_runs/" + str(int(time.time())) + "/"
    recon_file = HOME_PATH + "denoise/fill_1c.py"
    train_file = HOME_PATH + "denoise/train_fill_1c.py"
    python_files = [recon_file, train_file]
    # Not every model has a registered source file (e.g. saved models);
    # a missing entry must not abort the run after training has finished.
    model_filename = model_filenames.get(model_name)
    if model_filename is not None:
        python_files.append(model_filename)
    else:
        print("No model source file registered for", model_name)

    info_dict = {
        "model_name": model_name,
        "z_dim": z_dim,
        "lr": lr,
        "sched": sched,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "num_examples_train": num_examples_train,
        "num_examples_eval": num_examples_eval,
        "train_files": train_files,
        "eval_files": eval_files,
        "device": device,
        "normalize": normalize.__name__,
        "log_interval": log_interval,
        "tb_dirpath": tb_filename
    }

    save_run(save_dir, python_files, model, info_dict)
    for key, value in info_dict.items():
        print(key + ":", value)
Example #5
0
    filenames = []
    for lr in random_lrs:
        print('lr', lr)
        cur_filename_info = str(lr) + "-" + str(num_epochs) + "-" + str(
            int(time.time()))
        cur_filename = filename + "-" + cur_filename_info
        filenames += [cur_filename]
        cur_g_filename = g_filename + "-" + cur_filename_info
        cur_d_filename = d_filename + "-" + cur_filename_info
        discriminator, generator = train_gan(discriminator,
                                             generator,
                                             train_loader,
                                             num_epochs,
                                             batch_size,
                                             lr,
                                             lr,
                                             dtype,
                                             save_images=False)
        fake_images = []
        for i in range(16):
            fake_images += [generator(generate_noise(4))]
        inception_score = get_inception_score(fake_images)
        print("inception score", inception_score)
        stats = save_run(inception_score, lr, num_epochs, discriminator,
                         generator, cur_filename, cur_g_filename,
                         cur_d_filename)
        run_stats += [stats]
    print(run_stats)
    purge_poor_runs([], "./saved_runs/", purge_all=False)
    print("training finished")