Example #1
    def run(self, run_mode):
        if run_mode == 'train':
            self.empty_log(self.__C.VERSION)
            train_engine(self.__C, self.dataset, self.dataset_eval)

        elif run_mode == 'val':
            test_engine(self.__C, self.dataset, validation=True)

        elif run_mode == 'test':
            test_engine(self.__C, self.dataset)

        else:
            exit(-1)
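
For orientation, here is a minimal, self-contained sketch of the same train/val/test dispatch pattern; the stubbed engines, the --RUN flag, and the dummy config dict are illustrative assumptions, not code from the example above.

import argparse
import sys


def train_engine(cfg):
    # stand-in for the real training engine
    print('training with', cfg)


def test_engine(cfg, validation=False):
    # stand-in for the real evaluation engine
    print('validating' if validation else 'testing', 'with', cfg)


def run(cfg, run_mode):
    # dispatch exactly like the run() method shown above
    if run_mode == 'train':
        train_engine(cfg)
    elif run_mode == 'val':
        test_engine(cfg, validation=True)
    elif run_mode == 'test':
        test_engine(cfg)
    else:
        sys.exit(-1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--RUN', dest='run_mode',
                        choices=['train', 'val', 'test'], default='train')
    run({'VERSION': 'demo'}, parser.parse_args().run_mode)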
Example #2
def main():
    N = args.num_trajectories  # number of generated trajectories
    model = PECNet(hyperparams["enc_past_size"], hyperparams["enc_dest_size"],
                   hyperparams["enc_latent_size"], hyperparams["dec_size"],
                   hyperparams["predictor_hidden_size"],
                   hyperparams["non_local_theta_size"],
                   hyperparams["non_local_phi_size"],
                   hyperparams["non_local_g_size"], hyperparams["fdim"],
                   hyperparams["zdim"], hyperparams["nonlocal_pools"],
                   hyperparams["non_local_dim"], hyperparams["sigma"],
                   hyperparams["past_length"], hyperparams["future_length"],
                   args.verbose)
    model = model.double().to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    test_dataset = SocialDataset(set_name="test",
                                 b_size=hyperparams["test_b_size"],
                                 t_tresh=hyperparams["time_thresh"],
                                 d_tresh=hyperparams["dist_thresh"],
                                 verbose=args.verbose)

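    # shift each trajectory so its first point is the origin, then scale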
    for traj in test_dataset.trajectory_batches:
        traj -= traj[:, :1, :]
        traj *= hyperparams["data_scale"]

    # average ADE/FDE for k=20 (to account for variance in sampling)
    num_samples = 150
    test_error = defaultdict(lambda: 0)

    for _ in range(num_samples):
        test_error_dict = test_engine(test_dataset,
                                      model,
                                      device,
                                      hyperparams,
                                      best_of_n=N)
        test_error["ade"] += test_error_dict["ade"]
        test_error["fde"] += test_error_dict["fde"]

    for key in test_error:
        print(f"Average {key} = {test_error[key] / num_samples}")
Example #3
		init_wandb(hyperparams.copy(), model, args)
	
	# For the k_variation experiment only, load the pretrained model and run it
	if args.experiment == "k_variation":
		checkpoint = torch.load("../saved_models/PECNET_social_model1.pt", map_location=device)
		model.load_state_dict(checkpoint["model_state_dict"])
		train_dataset = SocialDataset(set_name="train", b_size=hyperparams["train_b_size"], t_tresh=hyperparams["time_thresh"], d_tresh=hyperparams["dist_thresh"], verbose=args.verbose)
		test_dataset = SocialDataset(set_name="test", b_size=hyperparams["test_b_size"], t_tresh=hyperparams["time_thresh"], d_tresh=hyperparams["dist_thresh"], verbose=args.verbose)
		# shift origin and scale data
		for traj in train_dataset.trajectory_batches:
			traj -= traj[:, 0:1, :]
			traj *= hyperparams["data_scale"]
		for traj in test_dataset.trajectory_batches:
			traj -= traj[:, 0:1, :]
			traj *= hyperparams["data_scale"]
		test_error_dict = test_engine(args.dataset, test_dataset, model, device, hyperparams, best_of_n = hyperparams['n_values'], experiment = args.experiment)
		print("Best ADE :" + test_error_dict['ade'])
		print("Best FDE :" + test_error_dict['fde'])
		return


	# initialize optimizer
	# optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])

	# initialize dataloaders
	if args.dataset == "drone":
		train_dataset = SocialDataset(set_name="train", b_size=hyperparams["train_b_size"], t_tresh=hyperparams["time_thresh"], d_tresh=hyperparams["dist_thresh"], verbose=args.verbose)
		test_dataset = SocialDataset(set_name="test", b_size=hyperparams["test_b_size"], t_tresh=hyperparams["time_thresh"], d_tresh=hyperparams["dist_thresh"], verbose=args.verbose)
		# shift origin and scale data
		for traj in train_dataset.trajectory_batches:
			traj -= traj[:, hyperparams["past_length"]:hyperparams["past_length"]+1, :]
Example #4
def train_engine(__C, dataset, dataset_eval=None):

    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    net = ModelLoader(__C).Net(__C, pretrained_emb, token_size, ans_size)
    net.cuda()
    net.train()

    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)

    # Cross-entropy loss for GQA/CLEVR, binary cross-entropy otherwise
    if __C.DATASET in ['gqa', 'clevr']:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='sum').cuda()
    else:
        loss_fn = torch.nn.BCELoss(reduction='sum').cuda()

    # Load checkpoint if resume training
    if __C.RESUME:
        print(' ========== Resume training')

        if __C.CKPT_PATH is not None:
            print('Warning: Now using CKPT_PATH args, '
                  'CKPT_VERSION and CKPT_EPOCH will not work')

            path = __C.CKPT_PATH
        else:
            path = __C.CKPTS_PATH + \
                   '/ckpt_' + __C.CKPT_VERSION + \
                   '/epoch' + str(__C.CKPT_EPOCH) + '.pkl'

        # Load the network parameters
        print('Loading ckpt from {}'.format(path))
        ckpt = torch.load(path)
        print('Finish!')
        net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']

        # Load the optimizer parameters
        optim = get_optim(__C, net, data_size, ckpt['lr_base'])
        optim._step = int(data_size / __C.BATCH_SIZE * start_epoch)
        optim.optimizer.load_state_dict(ckpt['optimizer'])

    else:
        if ('ckpt_' + __C.VERSION) in os.listdir(__C.CKPTS_PATH):
            shutil.rmtree(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

        os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

        optim = get_optim(__C, net, data_size)
        start_epoch = 0

    loss_sum = 0
    named_params = list(net.named_parameters())
    grad_norm = np.zeros(len(named_params))

    # Define multi-thread dataloader
    # if __C.SHUFFLE_MODE in ['external']:
    #     dataloader = Data.DataLoader(
    #         dataset,
    #         batch_size=__C.BATCH_SIZE,
    #         shuffle=False,
    #         num_workers=__C.NUM_WORKERS,
    #         pin_memory=__C.PIN_MEM,
    #         drop_last=True
    #     )
    # else:
    dataloader = Data.DataLoader(dataset,
                                 batch_size=__C.BATCH_SIZE,
                                 shuffle=True,
                                 num_workers=__C.NUM_WORKERS,
                                 pin_memory=__C.PIN_MEM,
                                 drop_last=True)

    # Training script
    for epoch in range(start_epoch, __C.MAX_EPOCH):

        # Save log to file
        logfile = open(__C.LOG_PATH + '/log_run_' + __C.VERSION + '.txt', 'a+')
        logfile.write('nowTime: ' +
                      datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
                      '\n')
        logfile.close()

        # Learning Rate Decay
        if epoch in __C.LR_DECAY_LIST:
            adjust_lr(optim, __C.LR_DECAY_R)

        # Externally shuffle data list
        # if __C.SHUFFLE_MODE == 'external':
        #     dataset.shuffle_list(dataset.ans_list)

        time_start = time.time()
        # Iteration
        for step, (frcn_feat_iter, grid_feat_iter, bbox_feat_iter,
                   ques_ix_iter, ans_iter) in enumerate(dataloader):

            optim.zero_grad()

            frcn_feat_iter = frcn_feat_iter.cuda()
            grid_feat_iter = grid_feat_iter.cuda()
            bbox_feat_iter = bbox_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            ans_iter = ans_iter.cuda()

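            # Gradient accumulation: back-propagate the batch in SUB_BATCH_SIZE
            # slices, then apply a single optimizer step after the loop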
            for accu_step in range(__C.GRAD_ACCU_STEPS):

                sub_frcn_feat_iter = \
                    frcn_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                  (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_grid_feat_iter = \
                    grid_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                  (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_bbox_feat_iter = \
                    bbox_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                  (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ques_ix_iter = \
                    ques_ix_iter[accu_step * __C.SUB_BATCH_SIZE:
                                 (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ans_iter = \
                    ans_iter[accu_step * __C.SUB_BATCH_SIZE:
                             (accu_step + 1) * __C.SUB_BATCH_SIZE]

                pred = net(sub_frcn_feat_iter, sub_grid_feat_iter,
                           sub_bbox_feat_iter, sub_ques_ix_iter)

                if __C.DATASET in ['gqa', 'clevr']:
                    loss = loss_fn(pred, sub_ans_iter.view(-1))
                else:
                    loss = loss_fn(torch.sigmoid(pred), sub_ans_iter)

                loss /= __C.GRAD_ACCU_STEPS
                loss.backward()
                loss_sum += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS

                if __C.VERBOSE:
                    if dataset_eval is not None:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['val']
                    else:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['test']

                    print(
                        "\r[Version %s][Model %s][Dataset %s][Epoch %2d][Step %4d/%4d][%s] Loss: %.4f, Lr: %.2e"
                        % (__C.VERSION, __C.MODEL_USE, __C.DATASET, epoch + 1,
                           step, int(data_size / __C.BATCH_SIZE), mode_str,
                           loss.cpu().data.numpy() / __C.SUB_BATCH_SIZE,
                           optim._rate),
                        end='          ')

            # Gradient norm clipping
            if __C.GRAD_NORM_CLIP > 0:
                nn.utils.clip_grad_norm_(net.parameters(), __C.GRAD_NORM_CLIP)

            # Save the gradient information
            for name in range(len(named_params)):
                norm_v = torch.norm(named_params[name][1].grad).cpu().data.numpy() \
                    if named_params[name][1].grad is not None else 0
                grad_norm[name] += norm_v * __C.GRAD_ACCU_STEPS
                # print('Param %-3s Name %-80s Grad_Norm %-20s'%
                #       (str(grad_wt),
                #        params[grad_wt][0],
                #        str(norm_v)))

            optim.step()

        time_end = time.time()
        print('Finished in {}s'.format(int(time_end - time_start)))
        #print('')
        epoch_finish = epoch + 1

        # Save checkpoint
        state = {
            'state_dict': net.state_dict(),
            'optimizer': optim.optimizer.state_dict(),
            'lr_base': optim.lr_base,
            'epoch': epoch_finish
        }
        torch.save(
            state, __C.CKPTS_PATH + '/ckpt_' + __C.VERSION + '/epoch' +
            str(epoch_finish) + '.pkl')

        # Logging
        logfile = open(__C.LOG_PATH + '/log_run_' + __C.VERSION + '.txt', 'a+')
        logfile.write('epoch = ' + str(epoch_finish) + '  loss = ' +
                      str(loss_sum / data_size) + '\n' + 'lr = ' +
                      str(optim._rate) + '\n\n')
        logfile.close()

        # Eval after every epoch
        if dataset_eval is not None:
            test_engine(__C,
                        dataset_eval,
                        state_dict=net.state_dict(),
                        validation=True)

        # if self.__C.VERBOSE:
        #     logfile = open(
        #         self.__C.LOG_PATH +
        #         '/log_run_' + self.__C.VERSION + '.txt',
        #         'a+'
        #     )
        #     for name in range(len(named_params)):
        #         logfile.write(
        #             'Param %-3s Name %-80s Grad_Norm %-25s\n' % (
        #                 str(name),
        #                 named_params[name][0],
        #                 str(grad_norm[name] / data_size * self.__C.BATCH_SIZE)
        #             )
        #         )
        #     logfile.write('\n')
        #     logfile.close()

        loss_sum = 0
        grad_norm = np.zeros(len(named_params))
Example #5
def train_engine(__C, dataset, dataset_eval=None):

    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    net = ModelLoader(__C).Net(
        __C,
        pretrained_emb,
        token_size,
        ans_size
    )
    net.cuda()
    net.train()

    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)

    # Define Loss Function
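    # (built from config strings; e.g. a BCELoss entry with LOSS_REDUCTION='sum'
    #  evaluates to torch.nn.BCELoss(reduction='sum').cuda())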
    loss_fn = eval('torch.nn.' + __C.LOSS_FUNC_NAME_DICT[__C.LOSS_FUNC] + "(reduction='" + __C.LOSS_REDUCTION + "').cuda()")

    # Load checkpoint if resume training
    if __C.RESUME:
        print(' ========== Resume training')

        if __C.CKPT_PATH is not None:
            print('Warning: Now using CKPT_PATH args, '
                  'CKPT_VERSION and CKPT_EPOCH will not work')

            path = __C.CKPT_PATH
        else:
            path = __C.CKPTS_PATH + \
                   '/ckpt_' + __C.CKPT_VERSION + \
                   '/epoch' + str(__C.CKPT_EPOCH) + '.pkl'

        # Load the network parameters
        print('Loading ckpt from {}'.format(path))
        ckpt = torch.load(path)
        print('Finish!')

        if __C.N_GPU > 1:
            net.load_state_dict(ckpt_proc(ckpt['state_dict']))
        else:
            net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']

        # Load the optimizer parameters
        optim = get_optim(__C, net, data_size, ckpt['lr_base'])
        optim._step = int(data_size / __C.BATCH_SIZE * start_epoch)
        optim.optimizer.load_state_dict(ckpt['optimizer'])
        
        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

    else:
        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            #shutil.rmtree(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

        optim = get_optim(__C, net, data_size)
        start_epoch = 0

    loss_sum = 0
    named_params = list(net.named_parameters())
    grad_norm = np.zeros(len(named_params))

    # Define multi-thread dataloader
    # if __C.SHUFFLE_MODE in ['external']:
    #     dataloader = Data.DataLoader(
    #         dataset,
    #         batch_size=__C.BATCH_SIZE,
    #         shuffle=False,
    #         num_workers=__C.NUM_WORKERS,
    #         pin_memory=__C.PIN_MEM,
    #         drop_last=True
    #     )
    # else:
    dataloader = Data.DataLoader(
        dataset,
        batch_size=__C.BATCH_SIZE,
        shuffle=True,
        num_workers=__C.NUM_WORKERS,
        pin_memory=__C.PIN_MEM,
        drop_last=True,
        multiprocessing_context='spawn'
    )

    logfile = open(
        __C.LOG_PATH +
        '/log_run_' + __C.VERSION + '.txt',
        'a+'
    )
    logfile.write(str(__C))
    logfile.close()

    # Training script
    for epoch in range(start_epoch, __C.MAX_EPOCH):

        # Save log to file
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            '=====================================\nnowTime: ' +
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
            '\n'
        )
        logfile.close()

        # Learning Rate Decay
        if epoch in __C.LR_DECAY_LIST:
            adjust_lr(optim, __C.LR_DECAY_R)

        # Externally shuffle data list
        # if __C.SHUFFLE_MODE == 'external':
        #     dataset.shuffle_list(dataset.ans_list)

        time_start = time.time()
        # Iteration
        for step, (
                frcn_feat_iter,
                grid_feat_iter,
                bbox_feat_iter,
                ques_ix_iter,
                ans_iter
        ) in enumerate(dataloader):

            optim.zero_grad()

            frcn_feat_iter = frcn_feat_iter.cuda()
            grid_feat_iter = grid_feat_iter.cuda()
            bbox_feat_iter = bbox_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            ans_iter = ans_iter.cuda()

            loss_tmp = 0
            for accu_step in range(__C.GRAD_ACCU_STEPS):
                loss_tmp = 0

                sub_frcn_feat_iter = \
                    frcn_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                  (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_grid_feat_iter = \
                    grid_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                  (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_bbox_feat_iter = \
                    bbox_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                  (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ques_ix_iter = \
                    ques_ix_iter[accu_step * __C.SUB_BATCH_SIZE:
                                 (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ans_iter = \
                    ans_iter[accu_step * __C.SUB_BATCH_SIZE:
                             (accu_step + 1) * __C.SUB_BATCH_SIZE]

                pred = net(
                    sub_frcn_feat_iter,
                    sub_grid_feat_iter,
                    sub_bbox_feat_iter,
                    sub_ques_ix_iter
                )

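                # Apply the preprocessing configured for this loss: 'flat' flattens
                # the tensor, any other non-empty entry is applied as a
                # torch.nn.functional op along dim=1 (e.g. F.log_softmax)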
                loss_item = [pred, sub_ans_iter]
                loss_nonlinear_list = __C.LOSS_FUNC_NONLINEAR[__C.LOSS_FUNC]
                for item_ix, loss_nonlinear in enumerate(loss_nonlinear_list):
                    if loss_nonlinear in ['flat']:
                        loss_item[item_ix] = loss_item[item_ix].view(-1)
                    elif loss_nonlinear:
                        loss_item[item_ix] = eval('F.' + loss_nonlinear + '(loss_item[item_ix], dim=1)')

                loss = loss_fn(loss_item[0], loss_item[1])
                if __C.LOSS_REDUCTION == 'mean':
                    # only mean-reduction needs to be divided by grad_accu_steps
                    loss /= __C.GRAD_ACCU_STEPS
                loss.backward()

                loss_tmp += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS
                loss_sum += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS

            if __C.VERBOSE:
                if dataset_eval is not None:
                    mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['val']
                else:
                    mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['test']

                # print("\r[Version %s][Model %s][Dataset %s][Epoch %2d][Step %4d/%4d][%s] Loss: %.4f, Lr: %.2e" % (
                #     __C.VERSION,
                #     __C.MODEL_USE,
                #     __C.DATASET,
                #     epoch + 1,
                #     step,
                #     int(data_size / __C.BATCH_SIZE),
                #     mode_str,
                #     loss_tmp / __C.SUB_BATCH_SIZE,
                #     optim._rate
                # ), end='          ')
                print("\r[Time Passed: %s][Version %s][Model %s][Dataset %s][Epoch %2d][Step %4d/%4d][%s] Loss: %.4f, Lr: %.2e" % (
                    time.strftime('%H:%M:%S', time.gmtime(time.time() - time_start)),
                    __C.VERSION,
                    __C.MODEL_USE,
                    __C.DATASET,
                    epoch + 1,
                    step,
                    int(data_size / __C.BATCH_SIZE),
                    mode_str,
                    loss_tmp / __C.SUB_BATCH_SIZE,
                    optim._rate
                ), end='          ')

            # Gradient norm clipping
            if __C.GRAD_NORM_CLIP > 0:
                nn.utils.clip_grad_norm_(
                    net.parameters(),
                    __C.GRAD_NORM_CLIP
                )

            # Save the gradient information
            for name in range(len(named_params)):
                norm_v = torch.norm(named_params[name][1].grad).cpu().data.numpy() \
                    if named_params[name][1].grad is not None else 0
                grad_norm[name] += norm_v * __C.GRAD_ACCU_STEPS
                # print('Param %-3s Name %-80s Grad_Norm %-20s'%
                #       (str(grad_wt),
                #        params[grad_wt][0],
                #        str(norm_v)))

            optim.step()

        time_end = time.time()
        elapse_time = time_end - time_start
        print('Finished in {}s'.format(int(elapse_time)))
        epoch_finish = epoch + 1

        # Save checkpoint
        if __C.N_GPU > 1:
            state = {
                'state_dict': net.module.state_dict(),
                'optimizer': optim.optimizer.state_dict(),
                'lr_base': optim.lr_base,
                'epoch': epoch_finish
            }
        else:
            state = {
                'state_dict': net.state_dict(),
                'optimizer': optim.optimizer.state_dict(),
                'lr_base': optim.lr_base,
                'epoch': epoch_finish
            }
        torch.save(
            state,
            __C.CKPTS_PATH +
            '/ckpt_' + __C.VERSION +
            '/epoch' + str(epoch_finish) +
            '.pkl'
        )

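        # Mirror the latest checkpoint to Google Drive (Colab) when the mount exists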
        drive_path = '/content/drive/My Drive/thesis/last_checkpoint'
        if os.path.exists(drive_path):
            clear_dir(drive_path)
            os.mkdir(drive_path + '/ckpt_' + __C.VERSION)
            torch.save(
                state,
                drive_path + 
                '/ckpt_' + __C.VERSION +
                '/epoch' + str(epoch_finish) +
                '.pkl'
            )

        # Logging
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            'Epoch: ' + str(epoch_finish) +
            ', Loss: ' + str(loss_sum / data_size) +
            ', Lr: ' + str(optim._rate) + '\n' +
            'Elapsed time: ' + str(int(elapse_time)) + 
            ', Speed(s/batch): ' + str(elapse_time / (step + 1)) +
            '\n\n'
        )
        logfile.close()

        # Eval after every epoch
        if dataset_eval is not None:
            test_engine(
                __C,
                dataset_eval,
                state_dict=net.state_dict(),
                validation=True
            )

        # if self.__C.VERBOSE:
        #     logfile = open(
        #         self.__C.LOG_PATH +
        #         '/log_run_' + self.__C.VERSION + '.txt',
        #         'a+'
        #     )
        #     for name in range(len(named_params)):
        #         logfile.write(
        #             'Param %-3s Name %-80s Grad_Norm %-25s\n' % (
        #                 str(name),
        #                 named_params[name][0],
        #                 str(grad_norm[name] / data_size * self.__C.BATCH_SIZE)
        #             )
        #         )
        #     logfile.write('\n')
        #     logfile.close()

        loss_sum = 0
        grad_norm = np.zeros(len(named_params))
Example #6
                                  b_size=hyperparams["test_b_size"],
                                  t_tresh=hyperparams["time_thresh"],
                                  d_tresh=hyperparams["dist_thresh"],
                                  verbose=args.verbose)

    best_ade = 50  # start saving after this threshold
    best_fde = 50
    best_metrics = {}
    N = hyperparams["n_values"]

    for e in range(hyperparams["num_epochs"]):
        train_loss_dict = train_engine(args.dataset, train_dataset, model,
                                       device, hyperparams, optimizer)
        test_error_dict = test_engine(args.dataset,
                                      test_dataset,
                                      model,
                                      device,
                                      hyperparams,
                                      best_of_n=N)

        if test_error_dict["ade"] < best_ade:
            best_ade = test_error_dict["ade"]
            best_metrics["best_ade"] = (best_ade, e)
            if best_ade < 10.25:
                save_path = "../saved_models/" + args.version + ".pt"
                torch.save(
                    {
                        "hyperparams": hyperparams,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict()
                    }, save_path)
                if args.wandb: