Example no. 1
def train_loop(model, loss_function, optimizer, scheduler, args):
    
    training_loss = []
    validation_loss = []
    recur_input_indx = list(range(args.nlevs))
    train_ldr = train_dataloader(args)
    test_ldr = test_dataloader(args)
    # NOTE: with this value the recursive-training branch in the loop below never runs
    recursive_train_interval = 1.1
    input_indices = list(range(args.nlevs))
    for epoch in range(1, args.epochs + 1):
        ## Training
        train_loss = 0
        for batch_idx, batch in enumerate(train_ldr):
            # Sets the model into training mode
            model.train()
            if batch_idx % 1 == recursive_train_interval:
                loss = recursive_training_step(batch, batch_idx, model, loss_function, optimizer, args.device, recur_input_indx)
            else:
                loss = training_step(batch, batch_idx, model, loss_function, optimizer, args.device, input_indices=input_indices)
            
            train_loss += loss.item()
            if batch_idx % args.log_interval == 0:
                x, y, y2 = batch
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.2e}'.format(
                    epoch, batch_idx * len(x), len(train_ldr.dataset) * args.batch_size,
                    100. * batch_idx / len(train_ldr), loss.item() / len(x)))
        average_loss = train_loss / len(train_ldr.dataset)
        print('====> Epoch: {} Average loss: {:.2e}'.format(epoch, average_loss))
        training_loss.append(average_loss)
        scheduler.step()
        ## Testing
        test_loss = 0
        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(test_ldr):
                loss = validation_step(batch, batch_idx, model, loss_function, args.device, input_indices=input_indices)
                test_loss += loss.item()
        average_loss_val = test_loss / len(test_ldr.dataset) 
        print('====> validation loss: {:.2e}'.format(average_loss_val))
        validation_loss.append(average_loss_val)
        if epoch % 2 == 0:
            checkpoint_name = args.model_name.replace('.tar','_chkepo_{0}.tar'.format(str(epoch).zfill(3)))
            torch.save({'epoch':epoch,
                'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'training_loss':training_loss,
                'validation_loss':validation_loss,
                'arguments':args},
                args.locations['model_loc']+'/'+checkpoint_name)
            # checkpoint_save(epoch, model, optimizer, training_loss, validation_loss, args.model_name, args.locations, args)            
    # Save the final model
    torch.save({'epoch':epoch,
                'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'training_loss':training_loss,
                'validation_loss':validation_loss,
                'arguments':args},
                args.locations['model_loc']+'/'+args.model_name)   
    return training_loss, validation_loss
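For reference, the checkpoint dictionary written by train_loop can be restored with a few lines; this is a minimal sketch, assuming the same model, optimizer and args objects have already been constructed.

# Sketch: restore the final checkpoint saved by train_loop above
# (assumes `model`, `optimizer` and `args` are built exactly as before).
checkpoint = torch.load(args.locations['model_loc'] + '/' + args.model_name)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
training_loss = checkpoint['training_loss']
validation_loss = checkpoint['validation_loss']
start_epoch = checkpoint['epoch'] + 1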
Example no. 2
def train_model(model, train_dataset, val_dataset, n_epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.L1Loss(reduction='sum')
    #criterion = torch.nn.MSELoss(reduction='mean',reduce=True, size_average=True)
    history = dict(train=[], val=[])

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 10000.0

    for epoch in range(1, n_epochs + 1):
        model = model.train()

        train_losses = []
        for seq_true in train_dataset:
            optimizer.zero_grad()

            seq_pred = model(seq_true)

            loss = criterion(seq_pred, seq_true)

            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        val_losses = []
        model = model.eval()
        with torch.no_grad():
            for seq_true in val_dataset:

                seq_pred = model(seq_true)

                loss = criterion(seq_pred, seq_true)
                val_losses.append(loss.item())

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)

        history['train'].append(train_loss)
        history['val'].append(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss}')

    model.load_state_dict(best_model_wts)
    return model.eval(), history
def save_checkpoint(checkpoint_path, model, optimizer):
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(state, checkpoint_path)
    print('model saved to %s' % checkpoint_path)
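The companion load is symmetric to save_checkpoint; a minimal sketch, assuming model and optimizer already exist and that 'checkpoint.pth' stands in for whatever path was used.

# Sketch: restore a checkpoint written by save_checkpoint above
# ('checkpoint.pth' is a placeholder path, not from the original project).
state = torch.load('checkpoint.pth')
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])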
Example no. 4
def train(args):
    # Bind the instance to a new local name so the (assumed) `model` module that
    # provides the factory is not shadowed inside this function.
    net = model.model(device=args.device)

    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.1)

    zero = torch.tensor([0.], device=args.device)
    one = torch.tensor([1.], device=args.device)

    for e in range(args.epochs):
        # Placeholder loss so the rest of the loop runs (requires_grad is needed
        # for the backward() call below); the model's real loss terms go here.
        # TODO think about numerical stability here
        loss = torch.zeros(1, device=args.device, requires_grad=True).mean()
        if e % 100 == 0:
            print(e, loss.item())

        if torch.isnan(loss).any():
            raise Exception(f'NaN loss on epoch {e}, terminating')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if e % 1000 == 0:
            torch.save(net.state_dict(), args.path + args.runid)
Example no. 5
def save_model(model, batch_idx, epoch, info='tmp'):
    date = time.strftime('%m-%d|%H:%M', time.localtime(time.time()))
    name = 'model_%s_%s_lr_%.1e_cur_lr_%s_l2_%.1e_batch_%d_e%d-%d_%s.%s.pt' % (
        opt.model, opt.info, opt.lr, opt.cur_lr, opt.l2, opt.batch_size, epoch,
        batch_idx, date, info)
    torch.save(model.state_dict(), os.path.join(opt.checkpoint, name))
    return name
Example no. 6
File: train.py Project: mfkiwl/fmr
def run(args, trainset, testset, action):
    if not torch.cuda.is_available():
        args.device = 'cpu'
    args.device = torch.device(args.device)

    model = action.create_model()
    if args.store and os.path.isfile(args.store):
        model.load_state_dict(torch.load(args.store, map_location='cpu'))

    if args.pretrained:
        assert os.path.isfile(args.pretrained)
        model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
    model.to(args.device)

    checkpoint = None
    if args.resume:
        assert os.path.isfile(args.resume)
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])

    # dataloader
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.workers)

    # optimizer
    min_loss = float('inf')
    learnable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(learnable_params)
    else:
        optimizer = torch.optim.SGD(learnable_params, lr=0.1)

    if checkpoint is not None:
        min_loss = checkpoint['min_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # training
    LOGGER.debug('train, begin')
    for epoch in range(args.start_epoch, args.epochs):
        running_loss = action.train(model, trainloader, optimizer, args.device)
        val_loss = action.validate(model, testloader, args.device)

        is_best = val_loss < min_loss
        min_loss = min(val_loss, min_loss)

        LOGGER.info('epoch, %04d, %f, %f', epoch + 1, running_loss, val_loss)
        print('epoch, %04d, floss_train=%f, floss_val=%f' %
              (epoch + 1, running_loss, val_loss))

        if is_best:
            save_checkpoint(model.state_dict(), args.outfile, 'model')

    LOGGER.debug('train, end')
def loading(path, model, nGPU):
    state_dict = model.state_dict()
    pretrained = torch.load('./weights/{}/model/model_best.pt'.format(path))
    print('pretrained model loaded~!')
    toupdate = dict()
    total = len(pretrained.keys())
    loaded = 0
    if nGPU < 2:
        for k,v in pretrained.items():
            if 'model.'+k in state_dict.keys():
                toupdate['model.'+k] = v
                loaded += 1
            elif k in state_dict.keys():
                toupdate[k] = v
                loaded += 1
    else:
        for k,v in pretrained.items():
            if 'model.module.'+k in state_dict.keys():
                toupdate['model.module.'+k] = v
                loaded += 1
            elif k in state_dict.keys():
                toupdate[k] = v
                loaded += 1
    print('total params: ', total, ', loaded params: ', loaded)
    state_dict.update(toupdate)
    model.load_state_dict(state_dict)
def save_mode(epoch=None,
              model=None,
              optimizer=None,
              test_acc=None,
              best_acc=None,
              test_acc_top5=None,
              filename=None,
              global_steps=0):
    global NORM_MEAN, NORM_STD, coconut_model, train_history_dict, class_to_idx, global_steps_train_history_dict
    state = {
        'epoch': epoch + 1,
        'args': args,
        'test_acc': test_acc,
        'best_acc': best_acc,
        'test_acc_top5': test_acc_top5,
        'class_to_idx': class_to_idx,
        'NORM_MEAN': NORM_MEAN,
        'NORM_STD': NORM_STD,
        'global_steps': global_steps,
        'train_history_dict': train_history_dict,
        'global_steps_train_history_dict': global_steps_train_history_dict,
        'model_state_dict': model.state_dict(),
        'model_optimizer': optimizer.state_dict()
    }
    torch.save(state, filename)
    print(filename + ' Saved!')
    return
Example no. 9
def train(data, epochs=100, batchSize=32, learningRate=1e-3):
    data = torch.from_numpy(data).cuda()

    dataset = TensorDataset(data)

    dataLoader = DataLoader(dataset, batch_size=batchSize, shuffle=True)

    model = EncoderDecoder().cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learningRate,
                                 weight_decay=1e-5)

    for epoch in range(epochs):
        for batch in dataLoader:
            batch = batch[0]
            output = model(batch)
            loss = criterion(output, batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch [%d/%d], loss:%.4f' %
              (epoch + 1, epochs, loss.item()))
        if epoch % 2 == 0:
            pic = output.detach().cpu()
            save_image(pic, 'outputs/%d.png' % (epoch))

    torch.save(model.state_dict(), 'models/encoderDecoder.pth')
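A minimal sketch of restoring the weights saved above, assuming the same EncoderDecoder class is importable.

# Sketch: reload the autoencoder weights written by train() above.
model = EncoderDecoder().cuda()
model.load_state_dict(torch.load('models/encoderDecoder.pth'))
model.eval()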
Example no. 10
def save_model(model, optimizer):
    
    os.environ['TZ'] = 'Singapore'
    time.tzset()
    date = time.strftime("%Y%m%d")
    timestamp = time.strftime("%H_%M_%S")
    if not os.path.exists('./saved_models/' + date):
        os.makedirs('./saved_models/' + date)
    path = './saved_models/' + date + '/' + timestamp
    
    checkpoint = {'state_dict': model.state_dict(),
                  'opti_state_dict': optimizer.state_dict(),
                  'model_lr': model.lr,
                  'model_nodes_num': model.nodes_num,
                  'model_features_num': model.features_num,
                  'model_input_timesteps': model.input_timesteps,
                  'model_num_output': model.num_output
                  }
    
    torch.save(checkpoint, path)
    
    f = open("./saved_models/last_saved_model.txt", "w")
    f.write(path)
    f.close()
    print(f"Model has been saved to path : {path}")
Example no. 11
def run(args, model, device, train_loader, test_loader, scheduler, optimizer):
    global best_prec1
    for epoch in range(args.start_epoch,
                       args.epochs):  # loop over the dataset multiple times
        train(args, model, device, train_loader, optimizer, epoch)
        prec1, prec5, loss = test(args, model, device, test_loader)
        # scheduler
        if args.scheduler == "MultiStepLR":
            scheduler.step()
        elif args.scheduler == "ReduceLROnPlateau":
            scheduler.step(loss)
        else:
            pass

        # remember the best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if args.detail and is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model_type,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict()
                }, is_best, args.checkpoint_path + '_' + args.model_type +
                '_' + str(args.model_structure))
Example no. 12
def save(model, opt, save_dir, epoch, args):
    os.makedirs(save_dir, exist_ok=True)

    # Save the current epoch
    save_filename = os.path.join(save_dir, "checkpoint-%d.pt" % epoch)
    torch.save({
        'gnet': model.state_dict(),
        'gopt': opt.state_dict()
    }, save_filename)

    # Save the most recent
    save_filename = os.path.join(save_dir, "checkpoint.pt")
    torch.save({
        'gnet': model.state_dict(),
        'gopt': opt.state_dict()
    }, save_filename)
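Resuming from the rolling "checkpoint.pt" written above is then straightforward; a minimal sketch, assuming model and opt are constructed the same way as during training.

# Sketch: resume from the most recent checkpoint saved by save() above.
state = torch.load(os.path.join(save_dir, "checkpoint.pt"))
model.load_state_dict(state['gnet'])
opt.load_state_dict(state['gopt'])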
Example no. 13
def train_model(model, optimizer, L, log_file, num_epochs=8000, ps=512):
    # Drop the learning rate to 1e-5 once, after 2000 epochs of training
    lr_dropped = False
    for epoch in range(num_epochs):
        if not lr_dropped and epoch > 2000:
            for g in optimizer.param_groups:
                g['lr'] = 0.00001
            lr_dropped = True
        epoch_loss = 0
        step = 0
        dt_size = len(train_ids)
        for ind in np.random.permutation(len(train_ids)):
            step += 1
            input_patch, gt = loaddata(ind, ps)
            optimizer.zero_grad()
            output = model(input_patch)
            loss = L(output, gt)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # print("%d/%d,train_loss:%0.3f" % (step, len(train_ids) + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss / step))
        log_file.write("epoch %d loss:%0.3f\n" % (epoch, epoch_loss / step))
        if (epoch + 1) % 200 == 0:
            log_file.flush()
            torch.save(model.state_dict(), './save/weights_%d.pth' % epoch)
    return model
Example no. 14
def train(data_loader, model, train=True):
    print("Training started ... ")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=C.LEARNING_RATE,
                           weight_decay=C.WEIGHT_DECAY)
    model.train()
    for epoch in range(C.EPOCH):
        print("Epoch %d started" % (epoch))
        running_loss = 0.0
        start_time = time.time()
        for batch_index, (inputs, label) in enumerate(data_loader):

            optimizer.zero_grad()  # Reset the gradients
            prediction = model(inputs)  # Feed forward
            loss = criterion(prediction, label.long())  # Compute losses
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if batch_index % C.PRINT_TRAIN_INFO_INTERVAL == C.PRINT_TRAIN_INFO_INTERVAL - 1:
                print(
                    '\t[%d, %5d/%5d] loss: %.3f  . Time %5dsec' %
                    (epoch, batch_index, data_loader.__len__(), running_loss /
                     C.PRINT_TRAIN_INFO_INTERVAL, time.time() - start_time))
                running_loss = 0.0

            if batch_index % C.SAVE_MODEL_INTERVAL == C.SAVE_MODEL_INTERVAL - 1:
                print('\tSaving model models/CassavaImagesDataset-%s-%s.pt' %
                      (str(epoch), str(batch_index)))
                torch.save(
                    model.state_dict(), "models/CassavaImagesDataset-" +
                    str(epoch) + "-" + str(batch_index) + ".pt")
    print("Training ended ... ")
Example no. 15
def main():
    # train and valid model
    train_loss_history = []
    valid_loss_history = []
    train_acc_history = []
    valid_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch+1, epochs))
        print('-' * 10)
        train_loss, train_acc = train_model()
        valid_loss, valid_acc = valid_model()
        train_loss_history.append(train_loss)
        valid_loss_history.append(valid_loss)
        train_acc_history.append(train_acc)
        valid_acc_history.append(valid_acc)
        if valid_acc>best_acc:
            best_acc = valid_acc
            best_model_wts = copy.deepcopy(model.state_dict())
    # save best accuracy on valid dataset
    print('Best val Acc: {:.4f}'.format(best_acc))
    if not os.path.exists(args.tag+'_backup'):
        os.mkdir(args.tag+'_backup')
    torch.save(best_model_wts, args.tag+'_backup/checkpoint.pt')
    
    # visualise loss diagram and accuracy diagram
    plt.figure(1)
    plt.plot(train_loss_history)
    plt.plot(valid_loss_history)
    maxposs = valid_acc_history.index(max(valid_acc_history))+1 
    plt.axvline(maxposs, linestyle='--', color='r')
    plt.gca().legend(('Train','Validation', 'Checkpoint'))
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.figure(2)
    plt.plot(train_acc_history)
    plt.plot(valid_acc_history)
    maxposs = valid_acc_history.index(max(valid_acc_history))+1 
    plt.axvline(maxposs, linestyle='--', color='r')
    plt.gca().legend(('Train','Validation','Checkpoint'))
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy: %')
    plt.show()

    # test model
    test_model()
Example no. 16
def save_model(base_path, model_name, model, optimizer, scheduler, opt):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    torch.save(state, os.path.join(base_path, model_name + '.pth'))
    print('model saved')
Example no. 17
 def load_vae(self, path):
     model = self.decoder.policy
     model_keys = set(model.state_dict().keys())
     states = torch.load(path)
     _state = {}
     for k in model_keys:
         _state[k] = states[k]
     model.load_state_dict(_state)
def control(queue, log, types, data, fout, distfn, nepochs, processes):
    min_rank = (np.Inf, -1)
    max_map = (0, -1)
    while True:
        gc.collect()
        msg = queue.get()
        if msg is None:
            for p in processes:
                p.terminate()
            break
        else:
            epoch, elapsed, loss, model = msg
        if model is not None:
            # save model to fout
            th.save(
                {
                    'model': model.state_dict(),
                    'epoch': epoch,
                    'objects': data.objects,
                }, fout)
            # compute embedding quality
            # mrank, mAP = ranking(types, model, distfn)  # TODO : SB
            ranks, ap_scores = ranking(types, model, distfn)
            mrank, mAP = np.mean(ranks), np.mean(ap_scores)
            if mrank < min_rank[0]:
                min_rank = (mrank, epoch)
            if mAP > max_map[0]:
                max_map = (mAP, epoch)
            log.info(
                ('eval: {'
                 '"epoch": %d, '
                 '"elapsed": %.2f, '
                 '"loss": %.3f, '
                 '"mean_rank": %.2f, '
                 '"mAP": %.4f, '
                 '"best_rank": %.2f, '
                 '"best_mAP": %.4f}') %
                (epoch, elapsed, loss, mrank, mAP, min_rank[0], max_map[0]))

            if model.dim == 2:
                plot_emb(types, model, epoch, data.objects, ap_scores, loss,
                         min_rank[0], max_map[0])
        else:
            log.info(
                f'json_log: {{"epoch": {epoch}, "loss": {loss}, "elapsed": {elapsed}}}'
            )
        if epoch >= nepochs - 1:
            log.info(
                ('results: {'
                 '"mAP": %g, '
                 '"mAP epoch": %d, '
                 '"mean rank": %g, '
                 '"mean rank epoch": %d'
                 '}') % (max_map[0], max_map[1], min_rank[0], min_rank[1]))
            break
    if model.dim == 2:
        plot_emb(types, model, epoch, data.objects, ap_scores, loss,
                 min_rank[0], max_map[0])
Example no. 19
def save_model(path, model, optimizer, step, args):
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step,
            "args": args,
        }, path + ".temp")
    os.rename(path + ".temp", path)  # Replace atomically
Example no. 20
def train(device, dataset, dataloader, model):
    print("in train")
    model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Training loop
    batch_count = 0
    images_per_batch = {'train': [], 'test': []}
    with tqdm(dataloader, total=config.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=device)
            train_targets = train_targets.to(device=device)
            train_embeddings = model(train_inputs)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=device)
            test_targets = test_targets.to(device=device)
            test_embeddings = model(test_inputs)

            prototypes = get_prototypes(train_embeddings, train_targets,
                                        dataset.num_classes_per_task)
            loss = prototypical_loss(prototypes, test_embeddings, test_targets)

            loss.backward()
            optimizer.step()

            #Just keeping the count here
            batch_count += 1
            images_per_batch['train'].append(train_inputs.shape[1])
            images_per_batch['test'].append(test_inputs.shape[1])

            with torch.no_grad():
                accuracy = get_accuracy(prototypes, test_embeddings,
                                        test_targets)
                pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

            if batch_idx >= config.num_batches:
                break

    print("Number of batches in the dataloader: ", batch_count)

    # Save model
    if check_dir() is not None:
        filename = os.path.join(
            'saved_models',
            'protonet_cifar_fs_{0}shot_{1}way.pt'.format(config.k, config.n))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
            print("Model saved")

    return batch_count, images_per_batch
Example no. 21
def save_model(model_name, model, optimizer, scheduler):
    make_folder(model_dir)
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    torch.save(state, os.path.join(model_dir, model_name + '.pth'))
    print('model saved')
Example no. 22
def save_model(model, model_name, niter, save_dir):
    save_name = '{}_{}.pth'.format(model_name, niter)
    if isinstance(model, nn.DataParallel) or isinstance(
            model, nn.parallel.DistributedDataParallel):
        model = model.module
    state_dict = model.state_dict()
    for key, param in state_dict.items():
        state_dict[key] = param.cpu()
    torch.save(state_dict, os.path.join(save_dir, save_name))
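Moving every parameter to the CPU and unwrapping DataParallel/DistributedDataParallel keeps the checkpoint loadable on any machine and strips the "module." key prefix, so the file loads into a plain (unwrapped) model. A minimal load sketch under those assumptions:

# Sketch: load a checkpoint written by save_model above into an unwrapped model.
state_dict = torch.load(os.path.join(save_dir, '{}_{}.pth'.format(model_name, niter)),
                        map_location='cpu')
model.load_state_dict(state_dict)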
Example no. 23
def save(model, optimizer, loss_list, epoch, dset_name, best, location):
    path = location + '/' + dset_name + ".tar"
    prefix_list = ["best_", "last_"] if best else ["last_"]
    # Build one dict with every prefixed entry and save it once; saving inside
    # the loop would overwrite the file on each iteration, losing the "best_" copy.
    state = {}
    for prefix in prefix_list:
        state.update({
            prefix + 'vae_state_dict': model.state_dict(),
            prefix + 'optim_state_dict': optimizer.state_dict(),
            prefix + 'loss_list': loss_list,
            prefix + 'epoch': epoch
        })
    torch.save(state, path)
Example no. 24
def save_checkpoint(model: nn.Module, optimizer: optim.Optimizer, epoch: int,
                    loss: float, filepath: str):
    """Saves model and optimizer state to a filepath."""
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss': loss,
            'optimizer_state_dict': optimizer.state_dict(),
        }, filepath)
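A symmetric load helper is often paired with this; a minimal sketch using the same keys (the function name is illustrative, not from the original project).

def load_checkpoint(model: nn.Module, optimizer: optim.Optimizer, filepath: str):
    """Sketch of the matching restore for save_checkpoint (hypothetical helper)."""
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']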
Example no. 25
 def recover_state(name, model, optimizer):
     # print(name)
     # print(list(model.state_dict().keys()))
     # for key in list(model.state_dict().keys()):
     #     print(key, model.state_dict()[key].size())
     state = model.state_dict()
     state.update(states[name]['state_dict'])
     model.load_state_dict(state)
     if args.loadOptim:
         optimizer.load_state_dict(states[name]['optimizer'])
Example no. 26
def save_model():
    global model, optimizer, model_config, model_path
    print('Saving to', model_path)
    torch.save(
        {
            'model_config': model_config,
            'model_state': model.state_dict(),
            'model_optimizer_state': optimizer.state_dict()
        }, model_path)
    print('Save complete')
Example no. 27
 def recover_state(name, model, optimizer):
     state = model.state_dict()
     model_keys = set(state.keys())
     load_keys = set(states[name]['state_dict'].keys())
     if model_keys != load_keys:
         print("NOTICE: DIFFERENT KEYS IN THE LISTEREN")
     state.update(states[name]['state_dict'])
     model.load_state_dict(state)
     if args.loadOptim:
         optimizer.load_state_dict(states[name]['optimizer'])
Example no. 28
def resnet50(config_channels, anchors, num_cls, **kwargs):
    model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 6, 3], **kwargs)
    if config_channels.config.getboolean('model', 'pretrained'):
        url = _model.model_urls['resnet50']
        logging.info('use pretrained model: ' + url)
        state_dict = model.state_dict()
        for key, value in model_zoo.load_url(url).items():
            if key in state_dict:
                state_dict[key] = value
        model.load_state_dict(state_dict)
    return model
def resnet18(config_channels, anchors, num_cls, **kwargs):
    model = ResNet(config_channels, anchors, num_cls, BasicBlock, [2, 2, 2, 2], **kwargs)
    if config_channels.config.getboolean('model', 'pretrained'):
        url = _model.model_urls['resnet18']
        logging.info('use pretrained model: ' + url)
        state_dict = model.state_dict()
        for key, value in model_zoo.load_url(url).items():
            if key in state_dict:
                state_dict[key] = value
        model.load_state_dict(state_dict)
    return model
Example no. 30
def save():
    model_state_dict = model.state_dict()
    model_source = {
        "settings": args,
        "model": model_state_dict,
        "word2idx": data['word2idx'],
        "char2idx": data['char2idx'],
        "max_len": data["max_len"],
        "predicate2id": data["predicate2id"],
    }
    torch.save(model_source, f"{os.path.join(args.model_path, 'model.pt')}")
Example no. 31
def resnet152(config_channels, **kwargs):
    model = ResNet(config_channels, Bottleneck, [3, 8, 36, 3], **kwargs)
    if config_channels.config.getboolean('model', 'pretrained'):
        url = _model.model_urls['resnet152']
        logging.info('use pretrained model: ' + url)
        state_dict = model.state_dict()
        for key, value in model_zoo.load_url(url).items():
            if key in state_dict:
                state_dict[key] = value
        model.load_state_dict(state_dict)
    return model
Example no. 32
def densenet201(config_channels, anchors, num_cls, **kwargs):
    model = DenseNet(config_channels, anchors, num_cls, num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
    if config_channels.config.getboolean('model', 'pretrained'):
        url = model_urls['densenet201']
        logging.info('use pretrained model: ' + url)
        state_dict = model.state_dict()
        for key, value in model_zoo.load_url(url).items():
            if key in state_dict:
                state_dict[key] = value
        model.load_state_dict(state_dict)
    return model
Example no. 33
 def finetune(self, model, path):
     if os.path.isdir(path):
         path, _step, _epoch = utils.train.load_model(path)
     _state_dict = torch.load(path, map_location=lambda storage, loc: storage)
     state_dict = model.state_dict()
     ignore = utils.RegexList(self.args.ignore)
     for key, value in state_dict.items():
         try:
             if not ignore(key):
                 state_dict[key] = _state_dict[key]
         except KeyError:
             logging.warning('%s not in finetune file %s' % (key, path))
     model.load_state_dict(state_dict)