def __init__(self, args):
        self.args = args
        self.config, self.output_dir, self.logger, self.device = common.init_experiment(
            args, seed=420)
        self.train_loader, self.train_data, self.val_data = data_utils.get_dataloaders(
            args['root'], self.config['dataloader'],
            self.config['model']['horizon'], self.device)
        self.train_data = self.train_data.view(-1)
        self.val_data = self.val_data.view(-1)

        # Models, optimizer and scheduler
        self.model = networks.Forecaster(self.config['model']).to(self.device)
        self.optim = train_utils.get_optimizer(config=self.config['optim'],
                                               params=self.model.parameters())
        self.scheduler, self.warmup_epochs = train_utils.get_scheduler(
            config={
                **self.config['scheduler'], 'epochs': self.config['epochs']
            },
            optimizer=self.optim)

        # Count model parameters
        total_params = common.count_parameters([self.model])
        if total_params >= 1e06:
            self.logger.record(
                f'Total trainable parameters: {round(total_params/1e06, 2)}M',
                mode='info')
        else:
            self.logger.record(f'Total trainable parameters: {total_params}',
                               mode='info')

        # Criterion and logging
        self.criterion = nn.MSELoss()
        self.best_val = 1e09
        self.best_mape = 10.0
        self.done_epochs = 0
        run = wandb.init(project='stock-prediction-nbeats')
        self.logger.write(run.get_url(), mode='info')

        # Warmup handling
        if self.warmup_epochs > 0:
            self.warmup_rate = self.optim.param_groups[0][
                'lr'] / self.warmup_epochs

        # Load best model if specified
        if args['load'] is not None:
            if os.path.exists(os.path.join(args['load'], 'best_model.ckpt')):
                self.load_model(args['load'])
                self.logger.record(
                    f"Successfully loaded saved model from {args['load']}",
                    mode='info')
            else:
                raise FileNotFoundError(
                    f"Could not load best_model.ckpt from {args['load']}; "
                    "please check your path")

        self.horizon = self.config['model']['horizon']
        self.lookback = self.config['model'][
            'lookback_horizon_ratio'] * self.horizon
Example #2
    def __init__(self, args):
        self.args = args
        self.config, self.output_dir, self.logger, self.device = common.init_experiment(
            args)

        # Initiate model, optimizer and scheduler
        assert self.config['model']['name'] in NETWORKS, \
            f"Unrecognized model name {self.config['model']['name']}"
        self.model = NETWORKS[self.config['model']['name']]['net'](
            pretrained=self.config['model']['pretrained']).to(self.device)
        self.optim = train_utils.get_optimizer(self.config['optimizer'],
                                               self.model.parameters())
        self.scheduler, self.warmup_epochs = train_utils.get_scheduler(
            {
                **self.config['scheduler'], "epochs": self.config["epochs"]
            }, self.optim)

        if self.warmup_epochs > 0:
            self.warmup_rate = (self.config['optimizer']['lr'] -
                                1e-12) / self.warmup_epochs

        # Dataloaders
        self.train_loader, self.val_loader, self.test_loader = data_utils.get_dataloaders(
            train_root=self.config['data']['train_root'],
            test_root=self.config['data']['test_root'],
            transforms=self.config['data']['transforms'],
            val_split=self.config['data']['val_split'],
            batch_size=self.config['data']['batch_size'])
        self.beta_dist = beta.Beta(self.config['data'].get("alpha", 0.3),
                                   self.config['data'].get("alpha", 0.3))
        self.batch_size = self.config['data']['batch_size']

        # Logging and model saving
        self.criterion = losses.LogLoss()
        self.best_val_loss = np.inf
        self.done_epochs = 0

        # Wandb
        run = wandb.init(project='deepfake-dl-hack')
        self.logger.write(f"Wandb: {run.get_url()}", mode='info')

        # Load model
        if args['load'] is not None:
            self.load_model(args['load'])
Example #3
def run(args):
    args.device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')
    print("Using device:", args.device)

    dataset, train_val_test_ratio = data_utils.get_graph_dataset(
        args.dataset_name, destination_dir=args.data_folder)

    cv_test_stats = []
    if args.test_emb:
        cv_baselines_test_stats = ut.BaselinesCVStats()
    for cv_fold in range(args.folds):
        print(f"Cross Validation Fold: {cv_fold}")
        dataset = dataset.shuffle()
        train_dataloader, val_dataloader, test_dataloader = data_utils.get_dataloaders(
            dataset,
            args.batch_size,
            "multi",
            train_val_test_ratio=[0.7, 0.1, 0.2],
            num_workers=1,
            shuffle_train=True)

        output_gc_dim = dataset.num_classes
        output_nc_dim = dataset[0].node_y.size(1)
        model = MultiTaskGCN(args.tasks,
                             dataset.num_node_features,
                             args.embedding_dim,
                             output_gc_dim,
                             output_nc_dim,
                             residual_con=args.residual_con,
                             normalize_emb=args.normalize_emb,
                             batch_norm=args.batch_norm,
                             dropout=args.dropout)
        model = model.to(args.device)

        train(model, train_dataloader, args, val_dataloader)

        if args.output_folder:
            if args.folds > 1:
                name = f"concurrent_multitask_gcn_cv_{cv_fold}"
            else:
                name = "concurrent_multitask_gcn"
            saved_dir = ut.save_model(model, args.output_folder, name, args)
            print("Model saved at path:", saved_dir)

        test_acc = test(model, test_dataloader, args)
        cv_test_stats.append(test_acc)

        if args.test_emb:
            if args.output_folder:
                folder_name = "baselines" if args.folds == 1 else f"baselines_cv_{cv_fold}"
                baselines_output_folder = os.path.join(args.output_folder,
                                                       folder_name)
            else:
                baselines_output_folder = None
            embedding_stats = test_embeddings.run_test(
                model,
                train_dataloader.dataset,
                val_dataloader.dataset,
                test_dataloader.dataset,
                epochs=100,
                batch_size=16,
                lr=1e-3,
                embedding_dim=args.embedding_dim,
                es_tmpdir=args.es_tmpdir,
                hidden_dim=args.embedding_dim,
                early_stopping=True,
                output_folder=baselines_output_folder,
                device=args.device)
            cv_baselines_test_stats.update(embedding_stats)

    print("\n\n############## Baseline Multitask GCN ##############")
    ut.print_cv_stats(cv_test_stats)
    if args.test_emb:
        cv_baselines_test_stats.print_stats()

    return cv_test_stats, model
Example #4
File: train.py  Project: lilleswing/SAME
def run(args):
    args.device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')
    print("Using device:", args.device)

    dataset, train_val_test_ratio = data_utils.get_graph_dataset(
        args.dataset_name, destination_dir=args.data_folder)

    cv_test_stats = []
    cv_test_stats_best_loss = []
    if args.test_emb:
        cv_baselines_test_stats = BaselinesCVStats()
    for cv_fold in range(args.folds):
        print(f"Cross Validation Fold: {cv_fold}", flush=True)
        dataset = dataset.shuffle()
        meta_train_dataloader, meta_val_dataloader, meta_test_dataloader = data_utils.get_dataloaders(
            dataset,
            args.batch_size,
            args.batch_task,
            train_val_test_ratio=train_val_test_ratio,
            num_workers=1)

        if args.meta_alg == "MAML":
            model_func = MultitaskGCN
        elif args.meta_alg == "ANIL":
            model_func = MultitaskGCN_2
        else:
            raise ValueError(f"Unrecognized meta algorithm: {args.meta_alg}")
        model = model_func(dataset.num_node_features,
                           args.embedding_dim,
                           dataset[0].node_y.size(1),
                           dataset.num_classes,
                           residual_con=args.residual_con,
                           normalize_emb=args.normalize_emb,
                           batch_norm=args.batch_norm,
                           dropout=args.dropout)
        model = model.to(args.device)

        train_batch_task_list = val_batch_task_list = test_batch_task_list = None
        if args.batch_task == "single":
            train_batch_task_list = data_utils.create_batch_task_list(
                len(meta_train_dataloader), tasks=args.tasks)
            val_batch_task_list = data_utils.create_batch_task_list(
                len(meta_val_dataloader), tasks=args.tasks)
            test_batch_task_list = data_utils.create_batch_task_list(
                len(meta_test_dataloader), tasks=args.tasks)
            prepare_batch_tasks_func = data_utils.single_task_train_test_split
        elif args.batch_task == "multi":
            prepare_batch_tasks_func = partial(
                data_utils.multi_task_train_test_split, tasks=args.tasks)
        elif args.batch_task == "conc":
            prepare_batch_tasks_func = partial(
                data_utils.concurrent_multi_task_train_test_split,
                tasks=args.tasks)

        with open(os.devnull, "w") as f, contextlib.ExitStack() as gs:
            if args.folds > 1:
                gs.enter_context(contextlib.redirect_stdout(f))
                gs.enter_context(contextlib.redirect_stderr(f))
            global_training_stats = meta_train(model, meta_train_dataloader,
                                               prepare_batch_tasks_func, args,
                                               train_batch_task_list,
                                               meta_val_dataloader,
                                               val_batch_task_list)

        if args.create_training_plots:
            cv_filename_prefix = ""
            if args.folds > 1:
                cv_filename_prefix = str(cv_fold)
            ut.create_stats_plots(global_training_stats, cv_filename_prefix)

        if args.output_folder:
            if args.folds > 1:
                name = f"cv_{cv_fold}"
            else:
                name = "multitask_gcn"
            saved_dir = ut.save_model(model, args.output_folder, name, args)
            print("Model saved at path:", saved_dir)

        # For testing the batch size is always 6
        tasks_test_stats = meta_test(model, meta_test_dataloader,
                                     prepare_batch_tasks_func, args,
                                     test_batch_task_list)
        cv_test_stats.append(tasks_test_stats)

        ## Try also model with best val loss
        if args.early_stopping:
            model_best_loss = copy.deepcopy(model)
            ut.recover_early_stopping_best_weights(model_best_loss,
                                                   args.es_tmpdir,
                                                   name="best_val_loss")
            tasks_test_stats = meta_test(model_best_loss, meta_test_dataloader,
                                         prepare_batch_tasks_func, args,
                                         test_batch_task_list)
            cv_test_stats_best_loss.append(tasks_test_stats)

        if args.test_emb:
            if args.output_folder:
                folder_name = "baselines" if args.folds == 1 else f"baselines_cv_{cv_fold}"
                baselines_output_folder = os.path.join(args.output_folder,
                                                       folder_name)
            else:
                baselines_output_folder = None
            with open(os.devnull, "w") as f, contextlib.ExitStack() as gs:
                print("Run Test Embeddings")
                if args.folds > 1:
                    gs.enter_context(contextlib.redirect_stdout(f))
                    gs.enter_context(contextlib.redirect_stderr(f))
                embedding_stats = test_embeddings.run_test(
                    model,
                    meta_train_dataloader.dataset,
                    meta_val_dataloader.dataset,
                    meta_test_dataloader.dataset,
                    epochs=100,
                    batch_size=8,
                    lr=1e-3,
                    embedding_dim=args.embedding_dim,
                    hidden_dim=args.embedding_dim,
                    early_stopping=True,
                    es_tmpdir=args.es_tmpdir,
                    output_folder=baselines_output_folder,
                    device=args.device)
                cv_baselines_test_stats.update(embedding_stats)

    print("\n\n############## Meta-Learned Multitask GCN ##############")
    print("Best Val Acc")
    ut.print_cv_stats(cv_test_stats)
    if args.early_stopping:
        print("\nBest Val Loss")
        ut.print_cv_stats(cv_test_stats_best_loss)
    if args.test_emb:
        cv_baselines_test_stats.print_stats()

    return cv_test_stats, model
Example #5
import argparse

parser = argparse.ArgumentParser()
# The snippet begins mid-call; '--model' and '--data' below are inferred from
# the later uses of args.model and args.data.
parser.add_argument('--data', type=str, required=True)
parser.add_argument('--model',
                    type=str,
                    choices=['base', 'iw', 'miw', 'ciw', 'betavae', 'piw'],
                    required=True)
parser.add_argument('--M', type=int, required=True)
parser.add_argument('--K', type=int, required=True)
parser.add_argument('--beta', type=float, required=False)
parser.add_argument('--ckpt_int', type=int, required=False)
args = parser.parse_args()

print_exp = "Experiment: {} {} {} ".format(args.model, args.M, args.K)
if args.beta is not None:
    print_exp += str(args.beta)
print(print_exp, flush=True)

batch_size = 256
train_loader, val_loader = get_dataloaders(args.data, batch_size)

sine_model_args = {
    'obs_dim': 1,
    'rec_latent_dim': 8,
    'node_latent_dim': 4,
    'rec_gru_unit': 100,
    'rec_node_hidden': 100,
    'rec_node_layer': 2,
    'rec_node_act': 'Tanh',
    'latent_node_hidden': 100,
    'latent_node_layer': 2,
    'latent_node_act': 'Tanh',
    'dec_type': 'NN',
    'dec_hidden': 100,
    'dec_layer': 1,