Example #1
def main(args):
    assert os.path.isfile(args.checkpoint), "Checkpoint file not found: {}".format(args.checkpoint)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    test_transforms = transforms.Compose([transforms.Resize((128, 128)),
                                          transforms.ToTensor()])

    # Initialize CLEVR Loader
    clevr_dataset_images = ClevrDatasetImages(args.clevr_dir, 'val', test_transforms)
    clevr_feat_extraction_loader = DataLoader(clevr_dataset_images, batch_size=args.batch_size,
                                              shuffle=False, num_workers=8, drop_last=True)

    args.features_dirs = './features'
    if not os.path.exists(args.features_dirs):
        os.makedirs(args.features_dirs)

    max_features = os.path.join(args.features_dirs, 'max_features.pickle')
    avg_features = os.path.join(args.features_dirs, 'avg_features.pickle')

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')

    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])
    model = RN(args)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    # Load the model checkpoint
    print('==> loading checkpoint {}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint)

    #removes 'module' from dict entries, pytorch bug #3805
    checkpoint = {k.replace('module.',''): v for k,v in checkpoint.items()}

    model.load_state_dict(checkpoint)
    print('==> loaded checkpoint {}'.format(args.checkpoint))

    max_features = open(max_features, 'wb')
    avg_features = open(avg_features, 'wb')

    extract_features_rl(clevr_feat_extraction_loader, max_features, avg_features, model, args)
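The snippet above opens two pickle files and hands them to extract_features_rl without showing the write format. A minimal reader sketch, assuming (not confirmed by the code shown) that the extractor calls pickle.dump once per batch of pooled features:

import pickle

def read_pickled_batches(path):
    # yield every object that was pickled into the file, one dump at a time
    with open(path, 'rb') as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                break

# hypothetical usage against the files created above
# max_feats = list(read_pickled_batches('./features/max_features.pickle'))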
Example #2
def main(args):
    args.model_dirs = './model_{}_b{}_lr{}'.format(args.model, args.batch_size,
                                                   args.lr)
    args.features_dirs = './features'
    if not os.path.exists(args.model_dirs):
        os.makedirs(args.model_dirs)

    args.test_results_dir = './test_results'
    if not os.path.exists(args.test_results_dir):
        os.makedirs(args.test_results_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')

    print('Initializing CLEVR dataset...')

    if (not args.state_description):
        train_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Pad(8),
            transforms.RandomCrop((128, 128)),
            transforms.RandomRotation(2.8),  # .05 rad
            transforms.ToTensor()
        ])
        test_transforms = transforms.Compose(
            [transforms.Resize((128, 128)),
             transforms.ToTensor()])

        clevr_dataset_train = ClevrDataset(args.clevr_dir, True, dictionaries,
                                           train_transforms)
        clevr_dataset_test = ClevrDataset(args.clevr_dir, False, dictionaries,
                                          test_transforms)

        # Use a weighted sampler for training:
        weights = clevr_dataset_train.answer_weights()
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(weights))

        # Initialize Clevr dataset loaders
        clevr_train_loader = DataLoader(clevr_dataset_train,
                                        batch_size=args.batch_size,
                                        sampler=sampler,
                                        num_workers=8,
                                        collate_fn=utils.collate_samples_image)
        clevr_test_loader = DataLoader(clevr_dataset_test,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=8,
                                       collate_fn=utils.collate_samples_image)
    else:
        clevr_dataset_train = ClevrDatasetStateDescription(
            args.clevr_dir, True, dictionaries)
        clevr_dataset_test = ClevrDatasetStateDescription(
            args.clevr_dir, False, dictionaries)

        # Initialize Clevr dataset loaders
        clevr_train_loader = DataLoader(
            clevr_dataset_train,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=8,
            collate_fn=utils.collate_samples_state_description)
        clevr_test_loader = DataLoader(
            clevr_dataset_test,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=8,
            collate_fn=utils.collate_samples_state_description)

    print('CLEVR dataset initialized!')

    # Build the model
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])
    model = RN(args)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    start_epoch = 1
    if args.resume:
        filename = args.resume
        if os.path.isfile(filename):
            print('==> loading checkpoint {}'.format(filename))
            checkpoint = torch.load(filename)

            #removes 'module' from dict entries, pytorch bug #3805
            #checkpoint = {k.replace('module.',''): v for k,v in checkpoint.items()}

            model.load_state_dict(checkpoint)
            print('==> loaded checkpoint {}'.format(filename))
            start_epoch = int(
                re.match(r'.*epoch_(\d+).pth', args.resume).groups()[0]) + 1

    if args.conv_transfer_learn:
        if os.path.isfile(args.conv_transfer_learn):
            # TODO: there may be problems caused by pytorch issue #3805 if using DataParallel

            print('==> loading conv layer from {}'.format(
                args.conv_transfer_learn))
            # pretrained dict is the dictionary containing the already trained conv layer
            pretrained_dict = torch.load(args.conv_transfer_learn)

            if torch.cuda.device_count() == 1:
                conv_dict = model.conv.state_dict()
            else:
                conv_dict = model.module.conv.state_dict()

            # filter only the conv layer from the loaded dictionary
            conv_pretrained_dict = {
                k.replace('conv.', '', 1): v
                for k, v in pretrained_dict.items() if 'conv.' in k
            }

            # overwrite entries in the existing state dict
            conv_dict.update(conv_pretrained_dict)

            # load the new state dict
            if torch.cuda.device_count() == 1:
                model.conv.load_state_dict(conv_dict)
                params = model.conv.parameters()
            else:
                model.module.conv.load_state_dict(conv_dict)
                params = model.module.conv.parameters()

            # freeze the weights for the convolutional layer by disabling gradient evaluation
            # for param in params:
            #     param.requires_grad = False

            print("==> conv layer loaded!")
        else:
            print('Cannot load file {}'.format(args.conv_transfer_learn))

    progress_bar = trange(start_epoch, args.epochs + 1)
    if args.test:
        # perform a single test
        print('Testing epoch {}'.format(start_epoch))
        test(clevr_test_loader, model, start_epoch, dictionaries, args)
    else:
        # perform a full training
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr,
                               weight_decay=1e-4)
        print('Training ({} epochs) is starting...'.format(args.epochs))
        for epoch in progress_bar:
            # TRAIN
            progress_bar.set_description('TRAIN')
            train(clevr_train_loader, model, optimizer, epoch, args)
            # TEST
            progress_bar.set_description('TEST')
            test(clevr_test_loader, model, epoch, dictionaries, args)
            # SAVE MODEL
            filename = 'RN_epoch_{:02d}.pth'.format(epoch)
            torch.save(model.state_dict(),
                       os.path.join(args.model_dirs, filename))
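answer_weights() is not shown above; a plausible implementation, and the usual way to balance a skewed answer distribution, is one weight per training sample equal to the inverse frequency of its answer. A sketch under that assumption, not code from the repository:

from collections import Counter
import torch
from torch.utils.data import WeightedRandomSampler

def inverse_frequency_weights(answers):
    # one weight per sample: 1 / (number of samples that share this answer)
    counts = Counter(answers)
    return torch.tensor([1.0 / counts[a] for a in answers], dtype=torch.double)

# hypothetical usage with per-sample answer indices
answers = [0, 0, 1, 2, 2, 2]
weights = inverse_frequency_weights(answers)
sampler = WeightedRandomSampler(weights, num_samples=len(weights))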
Example #3
def main(args):
    #load hyperparameters from configuration file
    with open(args.config) as config_file:
        hyp = json.load(config_file)['hyperparams'][args.model]
    # override the question injection position from the command line
    if args.question_injection >= 0:
        hyp['question_injection_position'] = args.question_injection

    print('Loaded hyperparameters from configuration {}, model: {}: {}'.format(
        args.config, args.model, hyp))

    assert os.path.isfile(
        args.checkpoint), "Checkpoint file not found: {}".format(
            args.checkpoint)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Initialize CLEVR Loader
    clevr_dataset_test = initialize_dataset(
        args.clevr_dir, True if args.set == 'train' else False,
        hyp['state_description'])
    clevr_feat_extraction_loader = reload_loaders(clevr_dataset_test,
                                                  args.batch_size,
                                                  hyp['state_description'])

    args.features_dirs = './features'
    if not os.path.exists(args.features_dirs):
        os.makedirs(args.features_dirs)

    files_dict = {}
    if args.extr_layer_idx >= 0:  #g_layers features
        files_dict['max_features'] = \
            open(os.path.join(args.features_dirs, '{}_2S-RN_max_features.pickle'.format(args.set,args.extr_layer_idx)),'wb')
        files_dict['avg_features'] = \
            open(os.path.join(args.features_dirs, '{}_2S-RN_avg_features.pickle'.format(args.set,args.extr_layer_idx)),'wb')
    else:
        '''files_dict['flatconv_features'] = \
            open(os.path.join(args.features_dirs, '{}_flatconv_features.pickle'.format(args.set)),'wb')'''
        files_dict['avgconv_features'] = \
            open(os.path.join(args.features_dirs, '{}_RN_avg_features.pickle'.format(args.set)),'wb')
        files_dict['maxconv_features'] = \
            open(os.path.join(args.features_dirs, '{}_RN_max_features.pickle'.format(args.set)),'wb')

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])

    print('Cuda: {}'.format(args.cuda))
    model = RN(args, hyp, extraction=True)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    # Load the model checkpoint
    print('==> loading checkpoint {}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint,
                            map_location=lambda storage, loc: storage)

    # removes 'module' from dict entries, pytorch bug #3805
    if torch.cuda.device_count() == 1 and any(
            k.startswith('module.') for k in checkpoint.keys()):
        print('Removing \'module.\' prefix')
        checkpoint = {
            k.replace('module.', ''): v
            for k, v in checkpoint.items()
        }
    if torch.cuda.device_count() > 1 and not any(
            k.startswith('module.') for k in checkpoint.keys()):
        print('Adding \'module.\' prefix')
        checkpoint = {'module.' + k: v for k, v in checkpoint.items()}

    model.load_state_dict(checkpoint)
    print('==> loaded checkpoint {}'.format(args.checkpoint))

    extract_features_rl(clevr_feat_extraction_loader,
                        hyp['question_injection_position'],
                        args.extr_layer_idx, hyp['lstm_hidden'], files_dict,
                        model, args)
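The 'module.' prefix handling above (and in several other examples here) can be factored into one helper; a small standalone sketch, not part of the original code:

def normalize_state_dict_keys(state_dict, wrapped_in_data_parallel):
    # add or strip the 'module.' prefix so the keys match the target model
    has_prefix = any(k.startswith('module.') for k in state_dict)
    if wrapped_in_data_parallel and not has_prefix:
        return {'module.' + k: v for k, v in state_dict.items()}
    if not wrapped_in_data_parallel and has_prefix:
        return {k[len('module.'):]: v for k, v in state_dict.items()}
    return state_dict

# hypothetical usage:
# checkpoint = normalize_state_dict_keys(checkpoint, isinstance(model, torch.nn.DataParallel))
# model.load_state_dict(checkpoint)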
Example #4
class Task():
    def __init__(self, args):
        print '#' * 60
        print ' ' * 20 + '    Task Created    ' + ' ' * 20
        print '#' * 60

        ######################################################################################################
        # Parameters
        self.batchSize = args.batchSize
        self.lr = args.lr
        self.weightDecay = 1e-4

        self.objNumMax = 30
        self.wordEmbeddingDim = 128
        self.instructionLength = 10
        self.pinMemory = True
        self.dropout = False

        self.epoch = args.epoch
        self.epoch_i = 0

        self.batchPrint = 100
        self.batchModelSave = args.batchModelSave
        self.checkPoint = args.checkPoint

        # Path
        self.vocabularyPath = './data/vocabulary.json'
        self.trainDatasetPath = './data/generated_data_train.json'
        self.testDatasetPath = './data/generated_data_test.json'
        self.logPath = args.logPath

        # Dataset
        self.trainDataset = DatasetGenerator(
            datasetPath=self.trainDatasetPath,
            vocabularyPath=self.vocabularyPath)
        self.testDataset = DatasetGenerator(datasetPath=self.testDatasetPath,
                                            vocabularyPath=self.vocabularyPath)

        # Tokenizer
        self.tokenizer = Tokenizer(vocabPath=self.vocabularyPath)
        self.num_embedding = self.tokenizer.get_num_embedding()

        # DataLoader
        self.trainDataLoader = DataLoader(dataset=self.trainDataset,
                                          shuffle=True,
                                          batch_size=self.batchSize,
                                          num_workers=12,
                                          pin_memory=self.pinMemory)
        self.testDataLoader = DataLoader(dataset=self.testDataset,
                                         shuffle=False,
                                         batch_size=self.batchSize,
                                         num_workers=12,
                                         pin_memory=self.pinMemory)
        # calculate batch numbers
        self.trainBatchNum = int(
            np.ceil(len(self.trainDataset) / float(self.batchSize)))
        self.testBatchNum = int(
            np.ceil(len(self.testDataset) / float(self.batchSize)))

        # Create model
        self.RN = RN(num_embedding=self.num_embedding,
                     embedding_dim=self.wordEmbeddingDim,
                     obj_num_max=self.objNumMax)

        # Run task on all available GPUs
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                print("Use ", torch.cuda.device_count(), " GPUs")
                self.RN = nn.DataParallel(self.RN)
            self.RN = self.RN.cuda()
            print 'Model Created on GPUs.'

        # Optimizer
        self.optimizer = optim.Adam(self.RN.parameters(),
                                    lr=self.lr,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=self.weightDecay)
        # Scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           factor=0.1,
                                           patience=10,
                                           mode='min')

        # Loss Function
        self.loss = torch.nn.CrossEntropyLoss()

        # Load model if a checkPoint is given
        if self.checkPoint != "":
            self.load(self.checkPoint)

        # TensorboardX record
        self.writer = SummaryWriter()
        self.stepCnt_train = 1
        self.stepCnt_test = 1

    def train(self):

        print 'Training task begin.'
        print '----Batch Size: %d' % self.batchSize
        print '----Learning Rate: %f' % (self.lr)
        print '----Epoch: %d' % self.epoch
        print '----Log Path: %s' % self.logPath

        for self.epoch_i in range(self.epoch):
            self.epochTrain()
            self.test()

    def epochTrain(self):
        s = '#' * 30 + '    Epoch %3d / %3d    ' % (self.epoch_i + 1,
                                                    self.epoch) + '#' * 30
        print s
        bar = tqdm(self.trainDataLoader)
        for idx, (objs_coordinate, objs_category, objs_category_idx,
                  instruction, instruction_idx, target,
                  data) in enumerate(bar):

            batchSize = objs_coordinate.shape[0]

            # to cuda
            if torch.cuda.is_available():
                objs_coordinate = objs_coordinate.cuda()
                objs_category_idx = objs_category_idx.long().cuda()
                instruction_idx = instruction_idx.long().cuda()
                target = target.cuda()

            # Go through the model
            output = self.RN(objs_coordinate, objs_category_idx,
                             instruction_idx)

            # calculate loss
            lossValue = self.loss(input=output, target=target)

            # Tensorboard record
            self.writer.add_scalar('Loss/Train', lossValue.item(),
                                   self.stepCnt_train)

            # print loss
            bar.set_description('Epoch: %d    Loss: %f' %
                                (self.epoch_i + 1, lossValue.item()))

            # Backward
            self.optimizer.zero_grad()
            lossValue.backward()
            self.optimizer.step()
            # self.scheduler.step(lossValue)

            # Save model
            if (idx + 1) % self.batchModelSave == 0:
                self.save(batchIdx=(idx + 1))

            if idx % self.batchPrint == 0:
                output = output.detach().cpu().numpy()
                target = target.detach().cpu().numpy()
                s = ''
                for batch_i in range(batchSize):
                    s += 'Target: '
                    s += str(target[batch_i])
                    s += ' Output: '
                    s += ', '.join(
                        [str(i) for i in output[batch_i, :].tolist()])
                    s += '    #########    '
                self.writer.add_text('Target & Output', s, self.stepCnt_train)

                self.writer.add_histogram('output', output, self.stepCnt_train)
                self.writer.add_histogram('target', target, self.stepCnt_train)

                for name, param in self.RN.named_parameters():
                    self.writer.add_histogram(name,
                                              param.clone().cpu().data.numpy(),
                                              self.stepCnt_train)

            self.stepCnt_train += 1
            del lossValue

    def test(self):
        s = '#' * 28 + '  Test  Epoch %3d / %3d    ' % (self.epoch_i + 1,
                                                        self.epoch) + '#' * 28
        print s

        bar = tqdm(self.testDataLoader)
        for idx, (objs_coordinate, objs_category, objs_category_idx,
                  instruction, instruction_idx, target,
                  data) in enumerate(bar):

            batchSize = objs_coordinate.shape[0]

            # to cuda
            if torch.cuda.is_available():
                objs_coordinate = objs_coordinate.cuda()
                objs_category_idx = objs_category_idx.cuda()
                instruction_idx = instruction_idx.cuda()
                target = target.cuda()

            # Go through the model
            output = self.RN(objs_coordinate, objs_category_idx,
                             instruction_idx)

            # calculate loss
            lossValue = self.loss(input=output, target=target)
            # Tensorboard record
            self.writer.add_scalar('Loss/Test', lossValue.item(),
                                   self.stepCnt_test)

            # print loss
            bar.set_description('Epoch: %d    Loss: %f' %
                                (self.epoch_i + 1, lossValue.item()))

            # Save model
            if (idx + 1) % self.batchModelSave == 0:
                self.save(batchIdx=(idx + 1))

            if idx % self.batchPrint == 0:
                output = output.detach().cpu().numpy()
                target = target.detach().cpu().numpy()
                s = ''
                for batch_i in range(batchSize):
                    s += 'Target: '
                    s += str(target[batch_i])
                    s += ' Output: '
                    s += ', '.join(
                        [str(i) for i in output[batch_i, :].tolist()])
                    s += '    #########    '
                self.writer.add_text('Target & Output Test', s,
                                     self.stepCnt_test)

                self.writer.add_histogram('output', output, self.stepCnt_test)
                self.writer.add_histogram('target', target, self.stepCnt_test)

                # for name, param in self.RN.named_parameters():
                #     self.writer.add_histogram(name, param.clone().cpu().data.numpy(), self.stepCnt_test)

            self.stepCnt_test += 1
            del lossValue

    def save(self, batchIdx=None):
        dirPath = os.path.join(self.logPath, 'models')

        if not os.path.exists(dirPath):
            os.mkdir(dirPath)

        if batchIdx is None:
            path = os.path.join(dirPath,
                                'Epoch-%03d-end.pth.tar' % (self.epoch_i + 1))
        else:
            path = os.path.join(
                dirPath,
                'Epoch-%03d-Batch-%04d.pth.tar' % (self.epoch_i + 1, batchIdx))

        torch.save(
            {
                'epochs': self.epoch_i + 1,
                'batch_size': self.batchSize,
                'lr': self.lr,
                'weight_decay': self.weightDecay,
                'RN_model_state_dict': self.RN.state_dict()
            }, path)
        print 'Training log saved to %s' % path

    def load(self, path):
        modelCheckpoint = torch.load(path)
        self.RN.load_state_dict(modelCheckpoint['RN_model_state_dict'])
        print 'Load model from %s' % path
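A hypothetical driver for the class above; the argument names are the ones the constructor reads (batchSize, lr, epoch, batchModelSave, checkPoint, logPath), but the parser and the default values are assumptions, since the snippet does not show them:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=32)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--epoch', type=int, default=50)
parser.add_argument('--batchModelSave', type=int, default=500)
parser.add_argument('--checkPoint', type=str, default='')
parser.add_argument('--logPath', type=str, default='./log')
args = parser.parse_args()

task = Task(args)
task.train()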
Example #5
File: test.py Project: munkim/rn-babi
# number of complete batches; any trailing lines that do not fill a full batch are dropped
num_batches = len(lines) // batch_size

# load vocabulary
word2idx = np.load('word2idx.npy').item()
idx2word = np.load('idx2word.npy').item()
vocab_size = len(word2idx)

if startoff>0:
    rn = torch.load('saved/rn_qa%d_epoch_%d_acc_0.970.pth' %(qa,startoff))
else:
    rn = RN(vocab_size, embed_size, en_hidden_size, mlp_hidden_size)
if torch.cuda.is_available():
    rn = rn.cuda()

total = 0
correct = 0

# evaluation (single pass over the shuffled test data)
for epoch in range(1):
    random.shuffle(lines) # shuffle lines
    for i in range(num_batches):
        batch = lines[i*batch_size:(i+1)*batch_size]
        S,Q,A = from_batch(batch)
        out = rn(S,Q)
        O = torch.max(out,1)[1].cpu().data.numpy().squeeze()
        score = np.array(O==A,int)
        total += len(score)
        correct += sum(score)
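The loop accumulates total and correct, but the truncated snippet never reports them; a summary along these lines presumably follows (an assumption, not shown in the file):

# report overall accuracy once every batch has been scored
accuracy = float(correct) / total if total > 0 else 0.0
print('Accuracy: {:.3f} ({}/{})'.format(accuracy, correct, total))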
Example #6
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# if args.model=='CNN_MLP':
#   model = CNN_MLP(args)
# else:
model = RN(args)

model_dirs = './model'
bs = args.batch_size
input_img = torch.FloatTensor(bs, 3, 75, 75)
input_qst = torch.FloatTensor(bs, 11)
label = torch.LongTensor(bs)

if args.cuda:
    model.cuda()
    input_img = input_img.cuda()
    input_qst = input_qst.cuda()
    label = label.cuda()

input_img = Variable(input_img)
input_qst = Variable(input_qst)
label = Variable(label)


def tensor_data(data, i):
    img = torch.from_numpy(np.asarray(data[0][bs * i:bs * (i + 1)]))
    qst = torch.from_numpy(np.asarray(data[1][bs * i:bs * (i + 1)]))
    ans = torch.from_numpy(np.asarray(data[2][bs * i:bs * (i + 1)]))

    input_img.data.resize_(img.size()).copy_(img)
    # the question and label tensors are filled the same way
    input_qst.data.resize_(qst.size()).copy_(qst)
    label.data.resize_(ans.size()).copy_(ans)
Example #7
def main(args):
    #load hyperparameters from configuration file
    with open(args.config) as config_file:
        hyp = json.load(config_file)['hyperparams'][args.model]
    # override dropout and the question injection position from the command line
    if args.dropout > 0:
        hyp['dropout'] = args.dropout
    if args.question_injection >= 0:
        hyp['question_injection_position'] = args.question_injection

    print('Loaded hyperparameters from configuration {}, model: {}: {}'.format(
        args.config, args.model, hyp))

    args.model_dirs = '{}/model_{}_drop{}_bstart{}_bstep{}_bgamma{}_bmax{}_lrstart{}_'+ \
                      'lrstep{}_lrgamma{}_lrmax{}_invquests-{}_clipnorm{}_glayers{}_qinj{}_fc1{}_fc2{}'
    args.model_dirs = args.model_dirs.format(
        args.exp_dir, args.model, hyp['dropout'], args.batch_size,
        args.bs_step, args.bs_gamma, args.bs_max, args.lr, args.lr_step,
        args.lr_gamma, args.lr_max, args.invert_questions, args.clip_norm,
        hyp['g_layers'], hyp['question_injection_position'], hyp['f_fc1'],
        hyp['f_fc2'])
    if not os.path.exists(args.model_dirs):
        os.makedirs(args.model_dirs)
    #create a file in this folder containing the overall configuration
    args_str = str(args)
    hyp_str = str(hyp)
    all_configuration = args_str + '\n\n' + hyp_str
    filename = os.path.join(args.model_dirs, 'config.txt')
    with open(filename, 'w') as config_file:
        config_file.write(all_configuration)

    args.features_dirs = '{}/features'.format(args.exp_dir)
    args.test_results_dir = '{}/test_results'.format(args.exp_dir)
    if not os.path.exists(args.test_results_dir):
        os.makedirs(args.test_results_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir, args.exp_dir)
    print('Word dictionary completed!')

    print('Initializing CLEVR dataset...')
    clevr_dataset_train, clevr_dataset_test = initialize_dataset(
        args.clevr_dir, args.exp_dir, dictionaries, hyp['state_description'])
    print('CLEVR dataset initialized!')

    # Build the model
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])

    model = RN(args, hyp)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    start_epoch = 1
    if args.resume:
        filename = args.resume
        if os.path.isfile(filename):
            print('==> loading checkpoint {}'.format(filename))
            checkpoint = torch.load(filename)

            #removes 'module' from dict entries, pytorch bug #3805
            if torch.cuda.device_count() == 1 and any(
                    k.startswith('module.') for k in checkpoint.keys()):
                checkpoint = {
                    k.replace('module.', ''): v
                    for k, v in checkpoint.items()
                }
            if torch.cuda.device_count() > 1 and not any(
                    k.startswith('module.') for k in checkpoint.keys()):
                checkpoint = {'module.' + k: v for k, v in checkpoint.items()}

            model.load_state_dict(checkpoint)
            print('==> loaded checkpoint {}'.format(filename))
            start_epoch = int(
                re.match(r'.*epoch_(\d+).pth', args.resume).groups()[0]) + 1

    if args.conv_transfer_learn:
        if os.path.isfile(args.conv_transfer_learn):
            # TODO: there may be problems caused by pytorch issue #3805 if using DataParallel

            print('==> loading conv layer from {}'.format(
                args.conv_transfer_learn))
            # pretrained dict is the dictionary containing the already trained conv layer
            pretrained_dict = torch.load(args.conv_transfer_learn)

            if torch.cuda.device_count() == 1:
                conv_dict = model.conv.state_dict()
            else:
                conv_dict = model.module.conv.state_dict()

            # filter only the conv layer from the loaded dictionary
            conv_pretrained_dict = {
                k.replace('conv.', '', 1): v
                for k, v in pretrained_dict.items() if 'conv.' in k
            }

            # overwrite entries in the existing state dict
            conv_dict.update(conv_pretrained_dict)

            # load the new state dict
            if torch.cuda.device_count() == 1:
                model.conv.load_state_dict(conv_dict)
                params = model.conv.parameters()
            else:
                model.module.conv.load_state_dict(conv_dict)
                params = model.module.conv.parameters()

            # freeze the weights for the convolutional layer by disabling gradient evaluation
            # for param in params:
            #     param.requires_grad = False

            print("==> conv layer loaded!")
        else:
            print('Cannot load file {}'.format(args.conv_transfer_learn))

    progress_bar = trange(start_epoch, args.epochs + 1)
    if args.test:
        # perform a single test
        print('Testing epoch {}'.format(start_epoch))
        _, clevr_test_loader = reload_loaders(clevr_dataset_train,
                                              clevr_dataset_test,
                                              args.batch_size,
                                              args.test_batch_size,
                                              hyp['state_description'])
        test(clevr_test_loader, model, start_epoch, dictionaries, args)
    else:
        bs = args.batch_size

        # perform a full training
        #TODO: find a better solution for general lr scheduling policies
        # exponent = number of completed lr_step intervals before start_epoch
        candidate_lr = args.lr * args.lr_gamma**((start_epoch - 1) //
                                                 args.lr_step)
        lr = candidate_lr if candidate_lr <= args.lr_max else args.lr_max

        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=lr,
                               weight_decay=1e-4)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, min_lr=1e-6, verbose=True)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        args.lr_step,
                                        gamma=args.lr_gamma)
        scheduler.last_epoch = start_epoch
        print('Training ({} epochs) is starting...'.format(args.epochs))
        for epoch in progress_bar:
            if ((args.bs_max > 0 and bs < args.bs_max)
                    or args.bs_max < 0) and (epoch % args.bs_step == 0
                                             or epoch == start_epoch):
                bs = math.floor(args.batch_size *
                                (args.bs_gamma**(epoch // args.bs_step)))
                if bs > args.bs_max and args.bs_max > 0:
                    bs = args.bs_max
                clevr_train_loader, clevr_test_loader = reload_loaders(
                    clevr_dataset_train, clevr_dataset_test, bs,
                    args.test_batch_size, hyp['state_description'])

                #restart optimizer in order to restart learning rate scheduler
                #for param_group in optimizer.param_groups:
                #    param_group['lr'] = args.lr
                #scheduler = lr_scheduler.CosineAnnealingLR(optimizer, step, min_lr)
                print('Dataset reinitialized with batch size {}'.format(bs))

            if ((args.lr_max > 0 and scheduler.get_lr()[0] < args.lr_max)
                    or args.lr_max < 0):
                scheduler.step()

            print('Current learning rate: {}'.format(
                optimizer.param_groups[0]['lr']))

            # TRAIN
            progress_bar.set_description('TRAIN')
            train(clevr_train_loader, model, optimizer, epoch, args)

            # TEST
            progress_bar.set_description('TEST')
            test(clevr_test_loader, model, epoch, dictionaries, args)

            # SAVE MODEL
            filename = 'RN_epoch_{:02d}.pth'.format(epoch)
            torch.save(model.state_dict(),
                       os.path.join(args.model_dirs, filename))
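The batch-size schedule inside the epoch loop above grows the batch size by bs_gamma every bs_step epochs and caps it at bs_max; the same arithmetic as a small standalone helper (a sketch for checking the schedule offline, not part of the original code):

import math

def scheduled_batch_size(base_bs, bs_gamma, bs_step, bs_max, epoch):
    # batch size produced by the update in the training loop for a given epoch
    bs = math.floor(base_bs * bs_gamma ** (epoch // bs_step))
    if bs_max > 0:
        bs = min(bs, bs_max)
    return bs

# e.g. base 64, gamma 2, step 20, cap 640:
# [scheduled_batch_size(64, 2, 20, 640, e) for e in (1, 20, 40, 80)] -> [64, 128, 256, 640]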
Example #8
class Task():
    def __init__(self, args):
        print '#' * 60
        print ' ' * 20 + '    Task Created    ' + ' ' * 20
        print '#' * 60

        ######################################################################################################
        # Parameters
        self.batchSize = args.batchSize
        self.lr = args.lr
        self.weightDecay = 1e-4

        self.objNumMax = 60
        self.wordEmbeddingDim = 64
        self.lstmHiddenDim = 128
        self.instructionLength = 10
        self.pinMemory = True
        self.dropout = False

        self.epoch = args.epoch
        self.epoch_i = 0

        self.batchPrint = 100
        self.batchModelSave = args.batchModelSave
        self.checkPoint = args.checkPoint

        # Path
        self.scanListTrain = '../data/scan_list_train.txt'
        self.scanListTest = '../data/scan_list_test.txt'
        self.datasetPath = '../generated_data'
        self.logPath = args.logPath

        # Dataset
        self.tokenizer = Tokenizer(encoding_length=self.instructionLength)
        self.trainDataset = DatasetGenerator(scanListPath=self.scanListTrain,
                                             datasetPath=self.datasetPath)
        self.testDataset = DatasetGenerator(scanListPath=self.scanListTest,
                                            datasetPath=self.datasetPath)

        # build vocabulary from all instructions in the training dataset
        self.tokenizer.build_vocab_from_dataset(self.trainDataset)

        # DataLoader
        self.trainDataLoader = DataLoader(dataset=self.trainDataset,
                                          shuffle=True,
                                          batch_size=self.batchSize,
                                          num_workers=12,
                                          pin_memory=self.pinMemory)
        self.testDataLoader = DataLoader(dataset=self.testDataset,
                                         shuffle=False,
                                         batch_size=self.batchSize,
                                         num_workers=12,
                                         pin_memory=self.pinMemory)
        # calculate batch numbers
        self.trainBatchNum = int(
            np.ceil(len(self.trainDataset) / float(self.batchSize)))
        self.testBatchNum = int(
            np.ceil(len(self.testDataset) / float(self.batchSize)))

        # Create model
        self.RN = RN(batch_size=self.batchSize,
                     num_objects=self.objNumMax,
                     vocab_size=self.tokenizer.get_vocal_length(),
                     embedding_size=self.wordEmbeddingDim,
                     hidden_size=self.lstmHiddenDim,
                     padding_idx=1,
                     dropout=self.dropout)

        # Run task on all available GPUs
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                print("Use ", torch.cuda.device_count(), " GPUs!")
                self.RN = nn.DataParallel(self.RN)
            self.RN = self.RN.cuda()
            print 'Model Created on GPUs.'

        # Optimizer
        self.optimizer = optim.Adam(self.RN.parameters(),
                                    lr=self.lr,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=self.weightDecay)

        # Scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           factor=0.1,
                                           patience=10,
                                           mode='min')

        # Loss Function
        self.loss = torch.nn.MSELoss()

        # Load model given a checkPoint
        if self.checkPoint != "":
            self.load(self.checkPoint)

        # create TensorboardX record
        self.writer = SummaryWriter(
            comment='word_embedding_64_lstm_hidden_state_128')
        self.stepCnt_train = 1
        self.stepCnt_test = 1

    def train(self):

        print 'Training task begin.'
        print '----Batch Size: %d' % self.batchSize
        print '----Learning Rate: %f' % (self.lr)
        print '----Epoch: %d' % self.epoch
        print '----Log Path: %s' % self.logPath

        for self.epoch_i in range(self.epoch):

            # if self.epoch_i == 0:
            #     self.save(batchIdx=0)  # Test the save function

            if self.epoch_i != 0:
                try:
                    self.map = self.map.eval()
                    self.test()
                except Exception, e:
                    print e

            self.RN = self.RN.train()
            self.epochTrain()
            self.save()