Beispiel #1
0
def train(model_file, X_train, X_val, max_n_epochs=300, max_batch_size=8):
    """Train a model on ``X_train`` (validating on ``X_val``), resuming if possible.

    If ``model_file`` is an existing file, the model is loaded with
    ``load_encoder_model`` and training resumes at the epoch after the one
    encoded in the file's base name (``...epoch_<N>...``). Otherwise a new
    model is built with ``networks.make_model(model_file)`` and training starts
    at epoch 0. A ``'siamese'`` substring anywhere in ``model_file`` switches
    both the batch generators and ``run_network`` into siamese mode.

    Args:
        model_file: Path of a model to resume from, or the name passed to
            ``networks.make_model`` when no such file exists.
        X_train: Training samples. ``shape[0]`` is used as the sample count and
            ``shape[1:]`` is appended to the batch size to form the shape passed
            to ``run_network``.
        X_val: Validation samples; only ``shape[0]`` is read here.
        max_n_epochs: Forwarded to ``run_network`` (presumably the epoch budget
            -- confirm against ``run_network``).
        max_batch_size: Upper bound on the batch size; the batch size actually
            used is ``min(X_train.shape[0], max_batch_size)``. Defaults to 8.

    Raises:
        ValueError: If ``X_train`` has no samples, if ``max_batch_size`` < 1, or
            if ``model_file`` exists but its base name has no ``epoch_<N>`` tag.
    """
    n_train = X_train.shape[0]
    if n_train == 0:
        # min(0, max_batch_size) would give a batch size of 0.
        raise ValueError('X_train is empty; cannot build training batches')
    if max_batch_size < 1:
        raise ValueError('max_batch_size must be >= 1, got {}'.format(max_batch_size))
    batch_size = min(n_train, max_batch_size)

    # The siamese variant is selected purely by the model file name.
    siamese = 'siamese' in model_file

    if os.path.isfile(model_file):  # continue training
        # Parse the epoch before the (potentially slow) load so a badly named
        # file fails fast. '+' rather than '*' so that 'epoch_' followed by no
        # digits is rejected here instead of crashing later in int('').
        match = re.search(r'(?<=epoch_)[0-9]+', os.path.basename(model_file))
        if match is None:
            raise ValueError(
                'cannot resume from {!r}: file name has no "epoch_<N>" tag'.format(model_file))
        start_epoch = int(match.group(0)) + 1
        model = load_encoder_model(model_file)
    else:  # start a new model
        model = networks.make_model(model_file)
        start_epoch = 0

    batch_gen_train, batch_gen_val = make_batch_generators(
        'train', X_train, X_val, batch_size, siamese=siamese)

    run_network(
        'train', batch_gen_train, batch_gen_val,
        n_train, X_val.shape[0],
        (batch_size,) + X_train.shape[1:],
        model, start_epoch, max_n_epochs,
        siamese=siamese)
Beispiel #2
0
    np_train_dataset, np_val_dataset = np_dataset

    torch.manual_seed(config.seed)

    train_rel_dataset, train_nonrel_dataset = make_dataset(np_train_dataset, rel_augmentation=True)
    train_dataset = torch.utils.data.ConcatDataset([train_rel_dataset, train_nonrel_dataset])
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, pin_memory=True, num_workers=config.n_cpu)

    val_rel_dataset, val_nonrel_dataset = make_dataset(np_val_dataset)
    val_rel_dataloader = torch.utils.data.DataLoader(val_rel_dataset, batch_size=config.batch_size, pin_memory=True)
    val_nonrel_dataloader = torch.utils.data.DataLoader(val_nonrel_dataset, batch_size=config.batch_size, pin_memory=True)


    model_name = '{}_{}_{}'.format(model_dict['n_res_blocks'], model_dict['n_channels'], config.tag)
    tb = tensorboard.tf_recorder(model_name, config.log_dir)
    net = make_model(model_dict).cuda()
    optimizer = optim.Adam(net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    criterion = nn.CrossEntropyLoss().cuda()

    if config.resume:
        print('load model from {}'.format(config.resume))
        prev = torch.load(config.resume)
        net.load_state_dict(prev['net'])
        optimizer.load_state_dict(prev['optimizer'])
    else:
        init_weights(net, config.init)
        
    for epoch in range(config.n_epoch):
        train_dict = train()
        val_dict = val()