def main():
    """Train a Faster R-CNN (ResNet-50 FPN) detector on ``parameters['image_dir']``.

    Creates the parent directories of the checkpoint and best-model paths.
    Builds the train/validation loaders from the image file names.
    Resumes from ``parameters['checkpoint_path']`` when ``load_ckp`` succeeds
    (otherwise starts at epoch 0), then hands everything to ``train_fn``.
    """
    # Ensure the parent directories of both output files exist. A bare
    # filename has an empty dirname and os.makedirs('') raises, so skip it.
    for key in ('checkpoint_path', 'best_model_path'):
        out_dir = os.path.dirname(parameters[key])
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Device : {device} is selected')

    # Sample ids are the image file names with their extension removed.
    # splitext handles extensions of any length, unlike a fixed [:-4] slice.
    file_names = [
        os.path.splitext(os.path.basename(path))[0]
        for path in glob.glob(os.path.join(parameters['image_dir'], '*'))
    ]
    train_loader, validation_loader = get_custom_dataset(
        file_names, parameters['train_frac'])

    model = pretrained_model('fasterrcnn_resnet50_fpn')
    model = model.to(device)

    # Only parameters that are not frozen are handed to the optimizer.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params,
                                lr=parameters['learning_rate'],
                                momentum=parameters['momentum'],
                                weight_decay=parameters['weight_decay'])

    # Resuming is best-effort: any failure falls back to a fresh run. The
    # underlying error is printed so that a corrupt or incompatible
    # checkpoint is not mistaken for a missing one.
    try:
        model, optimizer, start_epoch = load_ckp(parameters['checkpoint_path'],
                                                 model, optimizer)
    except Exception as err:
        print(
            'No Previous Checkpoint Found....Training the model from scratch')
        print(f'  (checkpoint load failed: {err!r})')
        start_epoch = 0

    best_loss = 1e10
    train_fn(start_epoch, parameters['epoch'], train_loader, validation_loader,
             model, device, optimizer, best_loss,
             parameters['checkpoint_path'], parameters['best_model_path'])
示例#2
0
def run():
    """Train a BERT / DistilBERT toxicity classifier and keep the best model.

    Steps:
        - choose the device: an XLA (TPU) device when ``config.TPUs`` is set,
          otherwise CUDA if available, else CPU;
        - build train/validation DataLoaders from the subsampled Jigsaw CSVs;
        - configure AdamW (no weight decay on biases / LayerNorm) with a
          linear LR schedule;
        - train for ``config.EPOCHS`` epochs, computing validation ROC-AUC
          after each one and saving ``model.state_dict()`` to
          ``config.MODEL_PATH`` whenever the AUC improves.

    Raises:
        ValueError: if ``config.MODEL`` is neither 'bert' nor 'distil-bert'.
    """
    if config.TPUs:
        n_TPUs = xm.xrt_world_size()  # number of TPU cores in the job
        DEVICE = xm.xla_device()
    else:
        DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(DEVICE)

    # Subsampled data. The full-size inputs were
    # data/jigsaw-toxic-comment-train.csv (optionally concatenated with
    # data/jigsaw-unintended-bias-train.csv) and data/validation.csv.
    df_train = pd.read_csv('data/jigsaw-toxic-comment-train-small.csv',
                           usecols=['comment_text', 'toxic'])
    df_valid = pd.read_csv('data/validation-small.csv',
                           usecols=['comment_text', 'toxic'])

    train_dataset = dataset.BERTDataset(
        comment=df_train.comment_text.values,
        target=df_train.toxic.values
    )
    valid_dataset = dataset.BERTDataset(
        comment=df_valid.comment_text.values,
        target=df_valid.toxic.values
    )

    drop_last = False
    train_sampler, valid_sampler = None, None
    if config.TPUs:
        # Shard both datasets across the TPU cores. drop_last keeps every
        # batch the same size (presumably to avoid XLA recompiling for a
        # ragged final batch).
        drop_last = True
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=n_TPUs,
            rank=xm.get_ordinal(),
            shuffle=True
        )
        # Sample order is irrelevant for evaluation, so do not shuffle.
        valid_sampler = DistributedSampler(
            valid_dataset,
            num_replicas=n_TPUs,
            rank=xm.get_ordinal(),
            shuffle=False
        )

    # Without a sampler (CPU/GPU path), let the DataLoader shuffle the
    # training data itself so both paths see shuffled epochs. DataLoader
    # rejects shuffle=True together with an explicit sampler.
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        num_workers=4,
        drop_last=drop_last,
        sampler=train_sampler,
        shuffle=train_sampler is None,
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=1,
        drop_last=drop_last,
        sampler=valid_sampler
    )

    if config.MODEL == 'bert':
        model = BERTBaseUncased()
    elif config.MODEL == 'distil-bert':
        model = DistilBERTBaseUncased()
    else:
        # A bare exit() would terminate with status 0 and hide the error.
        raise ValueError(
            f"Model chosen in config not valid: {config.MODEL!r} "
            "(expected 'bert' or 'distil-bert')")
    model.to(device)

    # Weight decay applies to every parameter except biases and LayerNorm
    # weights, which are matched by substring on the parameter name.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    lr = config.LR
    num_train_steps = int(len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
    if config.TPUs:
        # Each core only processes 1/n_TPUs of the data, so it takes that
        # many fewer optimizer steps (kept an int for the scheduler). The LR
        # is scaled up by the same factor, in line with the linear scaling
        # rule for an n_TPUs-times larger effective batch.
        num_train_steps //= n_TPUs
        lr *= n_TPUs

    optimizer = AdamW(optimizer_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    if not config.TPUs and N_GPU > 1:
        model = nn.DataParallel(model)

    best_score = 0
    for _ in range(config.EPOCHS):
        if config.TPUs:
            # ParallelLoader wraps the loader for per-device iteration on XLA.
            train_loader = pl.ParallelLoader(train_data_loader, [device])
            valid_loader = pl.ParallelLoader(valid_data_loader, [device])
            train_fn(train_loader.per_device_loader(device), model, optimizer, device, scheduler)
            outputs, targets = eval_fn(valid_loader.per_device_loader(device), model, device)
        else:
            train_fn(train_data_loader, model, optimizer, device, scheduler)
            outputs, targets = eval_fn(valid_data_loader, model, device)

        # roc_auc_score needs binary ground truth. The targets are
        # thresholded at 0.5, presumably because the validation 'toxic'
        # column may hold fractional scores — TODO confirm against the data.
        targets = np.array(targets) >= 0.5
        auc_score = metrics.roc_auc_score(targets, outputs)
        # NOTE(review): on TPU each core evaluates only its own validation
        # shard, so this AUC is per-core rather than global.

        print(f"AUC Score = {auc_score}")
        if auc_score > best_score:
            if not config.TPUs:
                torch.save(model.state_dict(), config.MODEL_PATH)
            else:
                xm.save(model.state_dict(), config.MODEL_PATH)
            best_score = auc_score
示例#3
0
def main():
    """Fine-tune ``TweetModel`` (RoBERTa-based) on ``config.TRAINING_FILE``.

    Uses an 80/20 train/validation split (``random_state=42``) and AdamW with
    a linear LR decay. Prints the validation Jaccard score after every epoch,
    and saves the whole final model object with ``torch.save`` to
    ``model.pth``.
    """
    dfx = pd.read_csv(config.TRAINING_FILE)
    df_train, df_valid = train_test_split(dfx, test_size=0.2, random_state=42)

    train_dataset = TweetDataset(
        tweet=df_train.text.values,
        sentiment=df_train.sentiment.values,
        selected_text=df_train.selected_text.values,
    )
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.TRAIN_BATCH_SIZE, num_workers=4)

    valid_dataset = TweetDataset(
        tweet=df_valid.text.values,
        sentiment=df_valid.sentiment.values,
        selected_text=df_valid.selected_text.values,
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=config.VALID_BATCH_SIZE, num_workers=2)

    # Fall back to CPU so the script does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = transformers.RobertaConfig.from_pretrained(
        config.ROBERTA_PATH)
    # Ask the backbone to return all hidden states (used by TweetModel).
    model_config.output_hidden_states = True
    model = TweetModel(conf=model_config)
    model.to(device)

    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS)

    # Weight decay applies to every parameter except biases and LayerNorm
    # weights, which are matched by substring on the parameter name.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.001,
        },
        {
            "params": [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_parameters, lr=3e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    for _ in range(config.EPOCHS):
        train_fn(train_data_loader,
                 model,
                 optimizer,
                 device,
                 scheduler=scheduler)
        jaccard = eval_fn(valid_data_loader, model, device)
        print(f"Jaccard Score = {jaccard}")

    torch.save(model, "model.pth")
示例#4
0
    # Build the model by passing in the input params
    model_class = XLMRobertaForSequenceClassification
    model = Classifier(model_class,
                       args.model_name,
                       num_labels=args.num_labels,
                       output_attentions=False,
                       output_hidden_states=False)
    # Send the model to the device
    model.to(device)

    # Train the model
    train_losses, train_accuracies, validation_losses, validation_accuracies = train_fn(
        model,
        train_dataloader,
        validation_dataloader,
        args.epochs,
        args.lr,
        device,
        args.best_model_path,
        args.torch_manual_seed,
        freeze_pretrained_encoder=args.freeze_pretrained_encoder)
    print('Training complete')

    # Plot training and validation losses and accuracies for n_epochs
    # plot(args.epochs, train_losses, train_accuracies, validation_losses, validation_accuracies)

    # Get model predictions on test-set data
    test_input = data.encode(test_df,
                             tokenizer,
                             max_len=args.max_sequence_length,
                             testing=True)
    test_data = TensorDataset(test_input['input_word_ids'],
示例#5
0
def main():
    """Parse command-line options for the MIL model and launch training.

    Sets ``args.out_dim = 2`` and an ``in_dim`` that depends on ``--dataset``
    and ``--pair_indicators``. Then calls ``train_fn`` once, or, when
    ``--n_folds`` is non-zero, once per fold in
    ``range(n_folds_done, n_folds)``. Each fold runs on a deep copy of
    ``args`` that carries ``fold_number``.

    Returns:
        -1 after printing 'Incorrect dataset' when ``--dataset`` is not one of
        'pdf', 'circles' or 'mog'; otherwise None.
    """
    parser = argparse.ArgumentParser()

    # For grouping in wandb
    parser.add_argument('--model_type', type=str, default='mil')
    parser.add_argument('--experiment', type=str, default='basic')
    parser.add_argument('--job', type=str, default='train')
    parser.add_argument('--save_dir', type=str, default='../runs/debug')

    # Run parameters
    parser.add_argument('--n_epochs', type=int, default=10000)
    parser.add_argument('--early_stop', type=int, default=10)
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--n_folds', type=int, default=0)
    parser.add_argument('--n_folds_done', type=int, default=0)
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.set_defaults(save_model=False)
    parser.add_argument('--test', dest='test', action='store_true')
    parser.set_defaults(test=False)
    parser.add_argument('--no_eval',
                        dest='eval_during_training',
                        action='store_false')
    parser.set_defaults(eval_during_training=True)

    # Model parameters
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--prep_n_layers', type=int, default=3)
    parser.add_argument('--after_n_layers', type=int, default=3)
    parser.add_argument('--hid_dim', type=int, default=256)
    parser.add_argument('--cluster_method', type=str, default='spectral')
    parser.add_argument('--pool_method', type=str, default='mean')

    # Dataset parameters
    parser.add_argument('--dataset', type=str, default='pdf')
    parser.add_argument('--debug_dataset', type=int, default=0)
    parser.add_argument('--n_sets_train', type=int, default=10000)
    parser.add_argument('--n_sets_val', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--n_elements', type=int, default=100)
    parser.add_argument('--n_clusters_low', type=int, default=2)
    parser.add_argument('--n_clusters_high', type=int, default=6)
    parser.add_argument('--circles_dont_shuffle',
                        dest='circles_shuffle',
                        action='store_false')
    parser.set_defaults(circles_shuffle=True)

    # Model specific dataset parameters
    parser.add_argument('--pair_indicators',
                        dest='pair_indicators',
                        action='store_true')
    parser.set_defaults(pair_indicators=False)
    parser.add_argument('--stratify', dest='stratify', action='store_true')
    parser.set_defaults(stratify=False)
    parser.add_argument('--n_pairs', type=int, default=50)
    parser.add_argument('--pair_bag_len', type=int, default=0)

    args, _ = parser.parse_known_args()

    args.out_dim = 2

    # Resolve the dataset-specific input width and dataloader factory once,
    # since neither depends on the fold. An unknown dataset is reported
    # (as the sibling entry points do) instead of silently doing nothing.
    if args.dataset == 'pdf':
        in_dim = 7 if args.pair_indicators else 18
        create_dataloaders = pdf_create_dataloaders
    elif args.dataset == 'circles':
        in_dim = 3 if args.pair_indicators else 6
        create_dataloaders = circles_create_dataloaders
    elif args.dataset == 'mog':
        in_dim = 3 if args.pair_indicators else 6
        create_dataloaders = mog_create_dataloaders
    else:
        print('Incorrect dataset')
        return -1

    if args.n_folds:
        for i in range(args.n_folds_done, args.n_folds):
            # Each fold gets its own copy so per-fold attributes never leak
            # into the shared args.
            args_copy = copy.deepcopy(args)
            args_copy.fold_number = i
            args_copy.in_dim = in_dim
            train_fn(args_copy, run_epoch, eval_fn, create_model,
                     create_dataloaders)
    else:
        args.in_dim = in_dim
        train_fn(args, run_epoch, eval_fn, create_model, create_dataloaders)
示例#6
0
def main():
    """Entry point for the permutation-equivariant ('permequi') model.

    Builds the CLI and trains either once on ``args`` or, when ``--n_folds``
    is set, once per fold in ``range(n_folds_done, n_folds)``. Each fold runs
    on a deep copy of ``args`` tagged with ``fold_number``. Every run gets
    ``out_dim = 1`` and a dataset-specific ``in_dim``: 2 for 'mog' and
    'circles', 6 for 'pdf'.

    Returns -1 (after printing 'Incorrect dataset') as soon as an unknown
    ``--dataset`` is encountered, otherwise None.
    """
    parser = argparse.ArgumentParser()

    # Grouping keys for wandb.
    parser.add_argument('--model_type', type=str, default='permequi')
    parser.add_argument('--experiment', type=str, default='basic')
    parser.add_argument('--job', type=str, default='train')
    parser.add_argument('--save_dir', type=str, default='../runs/debug')

    # Run control.
    parser.add_argument('--n_epochs', type=int, default=10000)
    # Validation loss tends to oscillate after roughly the 100th epoch, so
    # early stopping is given a little more patience here.
    parser.add_argument('--early_stop', type=int, default=20)
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--n_folds', type=int, default=0)
    parser.add_argument('--n_folds_done', type=int, default=0)
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.set_defaults(save_model=False)
    parser.add_argument('--test', dest='test', action='store_true')
    parser.set_defaults(test=False)
    parser.add_argument('--no_eval', dest='eval_during_training',
                        action='store_false')
    parser.set_defaults(eval_during_training=True)

    # Model hyper-parameters.
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--pme_version', type=str, default='max')
    parser.add_argument('--encoder_n_layers', type=int, default=3)
    parser.add_argument('--aencoder_n_layers', type=int, default=3)
    parser.add_argument('--decoder_n_layers', type=int, default=3)
    parser.add_argument('--hid_dim', type=int, default=256)
    parser.add_argument('--anchor_embed_method', type=str, default='cat')
    parser.add_argument('--pme_cluster', type=str, default='spectral')

    # Dataset options.
    parser.add_argument('--dataset', type=str, default='circles')
    parser.add_argument('--debug_dataset', type=int, default=0)
    parser.add_argument('--n_sets_train', type=int, default=100000)
    parser.add_argument('--n_sets_val', type=int, default=1000)
    parser.add_argument('--n_elements', type=int, default=100)
    parser.add_argument('--n_clusters_low', type=int, default=2)
    parser.add_argument('--n_clusters_high', type=int, default=6)
    parser.add_argument('--circles_dont_shuffle', dest='circles_shuffle',
                        action='store_false')
    parser.set_defaults(circles_shuffle=True)

    args, _ = parser.parse_known_args()

    def launch(run_args):
        """Set the I/O dims on ``run_args`` and train; -1 for an unknown dataset."""
        name = args.dataset
        if name == 'pdf':
            run_args.in_dim, run_args.out_dim = 6, 1
            train_fn(run_args, pdf_run_epoch, pdf_eval_fn, create_model,
                     pdf_create_dataloaders)
            return None
        if name in ('mog', 'circles'):
            run_args.in_dim, run_args.out_dim = 2, 1
            dataloaders = (mog_create_dataloaders if name == 'mog'
                           else circles_create_dataloaders)
            train_fn(run_args, artif_run_epoch, artif_eval_fn, create_model,
                     dataloaders)
            return None
        print('Incorrect dataset')
        return -1

    if not args.n_folds:
        return launch(args)

    for fold in range(args.n_folds_done, args.n_folds):
        fold_args = copy.deepcopy(args)
        fold_args.fold_number = fold
        if launch(fold_args) == -1:
            return -1
    return None
示例#7
0
__author__ = 'TramAnh'

from train import train_fn
from test import test_fn

if __name__ == '__main__':
    # Train a model with `hiddennodes` hidden nodes on the aligned training
    # CSV, then evaluate it on the aligned test CSV. Both calls receive the
    # same model_file path (presumably train_fn writes it and test_fn reads
    # it back — confirm in train.py / test.py).
    hiddennodes = 6
    trainfile = '../Data/alignment/train_aligned.csv'
    testfile = '../Data/alignment/test_aligned.csv'

    model_file = 'Serialized/model_{0}_nodes.pkl'.format(hiddennodes)
    train_fn(trainfile, hiddennodes, model_file)
    # print() call form: the Python 2 print statement is a SyntaxError on
    # Python 3, and this single-argument form also works on Python 2.
    print('Done with training')
    test_fn(testfile, hiddennodes, model_file)
def main():
    """Parse CLI options for the 'abc' model and start training.

    When ``--n_folds`` is non-zero, trains once per fold in
    ``range(n_folds_done, n_folds)``, each time on a deep copy of ``args``
    tagged with ``fold_number``. Otherwise trains once on ``args``.
    ``input_size`` is 2 for 'mog' and 'circles' and 6 for 'pdf'.

    Returns -1 (after printing 'Incorrect dataset') when an unknown
    ``--dataset`` is reached, otherwise None.
    """
    parser = argparse.ArgumentParser()

    # wandb grouping
    parser.add_argument('--model_type', type=str, default='abc')
    parser.add_argument('--experiment', type=str, default='basic')
    parser.add_argument('--job', type=str, default='train')
    parser.add_argument('--save_dir', type=str, default='../runs/debug')

    # Run control
    parser.add_argument('--n_epochs', type=int, default=10000)
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--early_stop', type=int, default=10)
    parser.add_argument('--n_folds', type=int, default=0)
    parser.add_argument('--n_folds_done', type=int, default=0)
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.set_defaults(save_model=False)
    parser.add_argument('--test', dest='test', action='store_true')
    parser.set_defaults(test=False)
    parser.add_argument('--no_eval', dest='eval_during_training',
                        action='store_false')
    parser.set_defaults(eval_during_training=True)

    # Model hyper-parameters
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--hidden_size', type=int, default=128)
    parser.add_argument('--n_enc_layers', type=int, default=2)
    parser.add_argument('--num_heads', type=int, default=1)
    parser.add_argument('--compat', type=str, default='multi')
    parser.add_argument('--isab', dest='isab', action='store_true')
    parser.set_defaults(isab=False)
    parser.add_argument('--cluster_method', type=str, default='spectral')

    # Dataset options
    parser.add_argument('--dataset', type=str, default='pdf')
    parser.add_argument('--debug_dataset', type=int, default=0)
    parser.add_argument('--n_sets_train', type=int, default=100000)
    parser.add_argument('--n_sets_val', type=int, default=1000)
    parser.add_argument('--n_elements', type=int, default=100)
    parser.add_argument('--n_clusters_low', type=int, default=2)
    parser.add_argument('--n_clusters_high', type=int, default=6)
    parser.add_argument('--circles_dont_shuffle', dest='circles_shuffle',
                        action='store_false')
    parser.set_defaults(circles_shuffle=True)

    args, _ = parser.parse_known_args()

    def start(run_args):
        # Set input_size on run_args and train; False flags an unknown dataset.
        name = args.dataset
        if name == 'pdf':
            run_args.input_size = 6
            train_fn(run_args, pdf_run_epoch, pdf_eval_fn, create_model,
                     pdf_create_dataloaders)
        elif name in ('mog', 'circles'):
            run_args.input_size = 2
            loaders = (mog_create_dataloaders if name == 'mog'
                       else circles_create_dataloaders)
            train_fn(run_args, artif_run_epoch, artif_eval_fn, create_model,
                     loaders)
        else:
            print('Incorrect dataset')
            return False
        return True

    if args.n_folds:
        for fold in range(args.n_folds_done, args.n_folds):
            fold_args = copy.deepcopy(args)
            fold_args.fold_number = fold
            if not start(fold_args):
                return -1
    elif not start(args):
        return -1
    return None
def run():
    """Fine-tune ``TweetModel`` on ``config.TRAINING_FILE`` and keep the best model.

    Drops rows with missing values and makes a 90/10 train/validation split
    stratified by sentiment. Trains for ``config.EPOCHS`` epochs with AdamW
    and a linear LR decay. Saves the model's ``state_dict`` to
    ``config.MODEL_PATH`` whenever the validation Jaccard score improves.
    """
    dfx = pd.read_csv(config.TRAINING_FILE).dropna().reset_index(drop=True)

    df_train, df_valid = model_selection.train_test_split(
        dfx, test_size=0.1, random_state=42, stratify=dfx.sentiment.values)
    df_train = df_train.reset_index(drop=True)
    df_valid = df_valid.reset_index(drop=True)

    train_dataset = TweetDataset(tweet=df_train.text.values,
                                 sentiment=df_train.sentiment.values,
                                 selected_text=df_train.selected_text.values)
    valid_dataset = TweetDataset(tweet=df_valid.text.values,
                                 sentiment=df_valid.sentiment.values,
                                 selected_text=df_valid.selected_text.values)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.TRAIN_BATCH_SIZE, num_workers=4)
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=config.VALID_BATCH_SIZE, num_workers=1)

    # Fall back to CPU so the script does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    conf = transformers.RobertaConfig.from_pretrained(config.ROBERTA_PATH)
    model = TweetModel(conf)
    model.to(device)

    # Weight decay applies to every parameter except biases and LayerNorm
    # weights, which are matched by substring on the parameter name.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.001,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]

    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
    optimizer = AdamW(optimizer_parameters, lr=3e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    # Always wrapped, so the saved state_dict keys carry the 'module.'
    # prefix that nn.DataParallel adds; loaders must expect that.
    model = nn.DataParallel(model)

    best_jaccard = 0
    for _ in range(config.EPOCHS):
        train_fn(train_data_loader, model, optimizer, device, scheduler)
        jaccard = eval_fn(valid_data_loader, model, device)
        print(f"Jaccard Score = {jaccard}")
        if jaccard > best_jaccard:
            torch.save(model.state_dict(), config.MODEL_PATH)
            best_jaccard = jaccard
示例#10
0
def main():
    """Parse CLI options for the 'dac' model and launch training.

    Parsing runs in two passes. The first pass reads ``--dataset``. The
    synthetic datasets ('circles', 'mog') then register their extra
    size/cluster options before the second pass. A ``--drop_p`` of 0 is
    stored as None. ``input_size`` is 2 for 'mog'/'circles' and 6 for 'pdf'.
    Trains once on ``args``, or, when ``--n_folds`` is set, once per fold in
    ``range(n_folds_done, n_folds)`` on a deep copy tagged with
    ``fold_number``.

    Returns -1 (after printing 'Incorrect dataset') as soon as an unknown
    ``--dataset`` is encountered, otherwise None.
    """
    parser = argparse.ArgumentParser()

    # For grouping in wandb
    parser.add_argument('--model_type', type=str, default='dac')
    parser.add_argument('--experiment', type=str, default='basic')
    parser.add_argument('--job', type=str, default='train')
    parser.add_argument('--save_dir', type=str, default='../runs/debug')

    # Run parameters
    parser.add_argument('--n_epochs', type=int, default=10000)
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--eval_freq', type=int, default=10)
    parser.add_argument('--early_stop', type=int, default=10)
    parser.add_argument('--n_folds', type=int, default=0)
    parser.add_argument('--n_folds_done', type=int, default=0)
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.set_defaults(save_model=False)
    parser.add_argument('--test', dest='test', action='store_true')
    parser.set_defaults(test=False)
    parser.add_argument('--no_eval',
                        dest='eval_during_training',
                        action='store_false')
    parser.set_defaults(eval_during_training=True)

    # Model parameters
    parser.add_argument('--lr', type=float, default=1e-04)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--dim', type=int, default=256)
    parser.add_argument('--num_blocks', type=int, default=8)
    parser.add_argument('--num_inds', type=int, default=32)
    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--no_isab2', dest='use_isab2', action='store_false')
    parser.set_defaults(use_isab2=True)
    parser.add_argument('--drop_p', type=float, default=0)
    parser.add_argument('--layer_norm', dest='ln', action='store_true')
    parser.set_defaults(ln=False)
    parser.add_argument('--max_iter', type=int, default=30)

    # Dataset parameters
    parser.add_argument('--dataset', type=str, default='pdf')
    parser.add_argument('--debug_dataset', type=int, default=0)

    # First pass: only --dataset is needed to decide which options to add.
    args, _ = parser.parse_known_args()

    # Both synthetic datasets share the same size/cluster options; 'pdf'
    # takes none of them.
    if args.dataset in ('circles', 'mog'):
        parser.add_argument('--n_sets_train', type=int, default=100000)
        parser.add_argument('--n_sets_val', type=int, default=1000)
        parser.add_argument('--n_elements', type=int, default=100)
        parser.add_argument('--n_clusters_low', type=int, default=2)
        parser.add_argument('--n_clusters_high', type=int, default=6)

    parser.add_argument('--circles_dont_shuffle',
                        dest='circles_shuffle',
                        action='store_false')
    parser.set_defaults(circles_shuffle=True)
    parser.add_argument('--augment_pdf_data',
                        dest='augment_pdf_data',
                        action='store_true')
    parser.set_defaults(augment_pdf_data=False)

    # Second pass picks up the dataset-specific options.
    args, _ = parser.parse_known_args()

    # A drop probability of 0 is stored as None (presumably create_model
    # treats None as "no dropout" — confirm).
    if args.drop_p == 0:
        args.drop_p = None

    def train_once(run_args):
        """Set input_size on ``run_args`` and train; -1 for an unknown dataset."""
        name = args.dataset
        if name == 'pdf':
            run_args.input_size = 6
            train_fn(run_args, pdf_run_epoch, pdf_eval_fn, create_model,
                     pdf_create_dataloaders)
            return None
        if name in ('mog', 'circles'):
            run_args.input_size = 2
            dataloaders = (mog_create_dataloaders if name == 'mog'
                           else circles_create_dataloaders)
            train_fn(run_args, artif_run_epoch, artif_eval_fn, create_model,
                     dataloaders)
            return None
        print('Incorrect dataset')
        return -1

    if not args.n_folds:
        return train_once(args)

    for i in range(args.n_folds_done, args.n_folds):
        # Each fold gets its own copy so per-fold attributes never leak
        # into the shared args.
        args_copy = copy.deepcopy(args)
        args_copy.fold_number = i
        if train_once(args_copy) == -1:
            return -1
    return None
示例#11
0
#%% Train the model
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(args.exp_dir, exist_ok=True)

if args.restore_file is not None:
    restore_path = os.path.join(args.exp_dir, args.restore_file + '.pth.tar')
    utils.load_checkpoint(restore_path, model, optimizer)

# Args plus per-epoch train/valid scores, collected for JSON output.
output_dict = {'args': vars(args), 'prfs': {}}
max_valid_f1 = -float('inf')

for epoch in range(args.num_epochs):
    train_scores = train_fn(model, train_loader, optimizer, scheduler, loss_fn,
                            utils.metrics_fn, device, args.clip,
                            args.accum_step, args.threshold)
    valid_scores = valid_fn(model, valid_loader, loss_fn, utils.metrics_fn,
                            device, args.threshold)

    # Keys use 1-based epoch numbers, e.g. 'train_1' / 'valid_1'.
    output_dict['prfs'][f'train_{epoch + 1}'] = train_scores
    output_dict['prfs'][f'valid_{epoch + 1}'] = valid_scores

    # Track the best validation F1 and write its full score dict to disk.
    is_best = valid_scores['f1'] > max_valid_f1
    if is_best:
        max_valid_f1 = valid_scores['f1']
        utils.save_dict_to_json(
            valid_scores, os.path.join(args.exp_dir, 'best_val_scores.json'))

    # Save model
示例#12
0
        if fold_number in args.to_run_folds:
            data_train_ = ShopeeDataset(train[train['fold'] != fold_number].reset_index(drop=True))
            data = DataLoader(data_train_, batch_size=args.batch_size,
                              num_workers=args.num_workers)
            data_valid_ = ShopeeDataset(train[train['fold'] == fold_number].reset_index(drop=True))
            data_valid = DataLoader(data_valid_, batch_size=args.batch_size_test,
                                    num_workers=args.num_workers)

            model = ShopeeModel().to(args.device)
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.scheduler_params['lr_start'])
            if args.scheduler:
                scheduler = ShopeeScheduler(optimizer, **args.scheduler_params)
            else:
                scheduler = None

            for epoch in range(args.n_epochs):
                model_dir = f'epoch{epoch}_arcface_{args.crop_size}x' \
                            f'{args.crop_size}_{args.backbone}' \
                            f'fold_{fold_number}.pt'
                avg_loss_train = train_fn(model, data, optimizer, scheduler, epoch, args.device)
                avg_loss_valid = eval_fn(model, data_valid, epoch)
                print(
                    f'TRAIN LOSS : {avg_loss_train}  VALIDATION LOSS : {avg_loss_valid}'
                )
                torch.save(model.state_dict(), args.output + model_dir)
                torch.save(dict(epoch=epoch, model_state_dict=model.state_dict(),
                                optimizer=optimizer.state_dict(),scheduler=scheduler.state_dict()),
                           args.output + 'checkpoints_' + model_dir
                           )