Example #1
File: trainer.py  Project: Willamjie/SCAN-1
    def train(self):
        model = SCAN(self.params)
        model.apply(init_xavier)
        # Overwrite the fresh Xavier init with previously saved weights to resume training
        model.load_state_dict(torch.load('models/model_weights_5.t7'))
        loss_function = MarginLoss(self.params.margin)
        if torch.cuda.is_available():
            model = model.cuda()
            loss_function = loss_function.cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.params.lr,
                                     weight_decay=self.params.wdecay)
        try:
            prev_best = 0
            for epoch in range(self.params.num_epochs):
                iters = 1
                losses = []
                start_time = timer()
                num_of_mini_batches = len(
                    self.data_loader.train_ids) // self.params.batch_size
                for (caption, mask, image, neg_cap, neg_mask,
                     neg_image) in tqdm(self.data_loader.training_data_loader):

                    # Sample according to hard negative mining
                    caption, mask, image, neg_cap, neg_mask, neg_image = self.data_loader.hard_negative_mining(
                        model, caption, mask, image, neg_cap, neg_mask,
                        neg_image)
                    model.train()
                    optimizer.zero_grad()
                    # forward pass.
                    similarity = model(to_variable(caption), to_variable(mask),
                                       to_variable(image), False)
                    similarity_neg_1 = model(to_variable(neg_cap),
                                             to_variable(neg_mask),
                                             to_variable(image), False)
                    similarity_neg_2 = model(to_variable(caption),
                                             to_variable(mask),
                                             to_variable(neg_image), False)

                    # Compute the loss, gradients, and update the parameters by calling optimizer.step()
                    loss = loss_function(similarity, similarity_neg_1,
                                         similarity_neg_2)
                    loss.backward()
                    losses.append(loss.item())
                    if self.params.clip_value > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.params.clip_value)
                    optimizer.step()

                    # sys.stdout.write("[%d/%d] :: Training Loss: %f   \r" % (
                    #     iters, num_of_mini_batches, float(np.mean(losses))))
                    # sys.stdout.flush()
                    iters += 1

                if (epoch + 1) % self.params.step_size == 0:
                    optim_state = optimizer.state_dict()
                    optim_state['param_groups'][0]['lr'] = optim_state[
                        'param_groups'][0]['lr'] / self.params.gamma
                    optimizer.load_state_dict(optim_state)

                torch.save(
                    model.state_dict(), self.params.model_dir +
                    '/model_weights_{}.t7'.format(epoch + 1))

                # Calculate r@k after each epoch
                if (epoch + 1) % self.params.validate_every == 0:
                    r_at_1, r_at_5, r_at_10 = self.evaluator.recall(
                        model, is_test=False)

                    print(
                        "Epoch {} : Training Loss: {:.5f}, R@1 : {}, R@5 : {}, R@10 : {}, Time elapsed {:.2f} mins"
                        .format(epoch + 1, float(np.mean(losses)),
                                r_at_1, r_at_5, r_at_10,
                                (timer() - start_time) / 60))
                    if r_at_1 > prev_best:
                        print("Recall at 1 increased....saving weights !!")
                        prev_best = r_at_1
                        torch.save(
                            model.state_dict(), self.params.model_dir +
                            '/best_model_weights_{}.t7'.format(epoch + 1))
                else:
                    print("Epoch {} : Training Loss: {:.5f}".format(
                        epoch + 1, float(np.mean(losses))))
        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(model.state_dict(),
                       self.params.model_dir + '/model_weights_interrupt.t7')
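
The loop above scores one positive pair and two hard negatives (a negative caption against the image, and the caption against a negative image), then feeds all three similarity tensors to `MarginLoss(self.params.margin)`. The project's own `MarginLoss` is not shown in this snippet; the following is a minimal sketch of a triplet hinge loss with that call signature, written as an assumption about what it computes rather than the project's actual definition.

import torch
import torch.nn as nn

class MarginLoss(nn.Module):
    """Sketch only: hinge loss pushing the positive similarity above
    both hard-negative similarities by at least `margin`."""

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, sim_pos, sim_neg_caption, sim_neg_image):
        # max(0, margin - s(c, i) + s(c', i)) for the hard-negative caption ...
        loss_caption = torch.clamp(self.margin - sim_pos + sim_neg_caption, min=0)
        # ... and max(0, margin - s(c, i) + s(c, i')) for the hard-negative image.
        loss_image = torch.clamp(self.margin - sim_pos + sim_neg_image, min=0)
        return (loss_caption + loss_image).mean()
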
Example #2

def main():
    # Hyper-parameters
    opt = opts.parse_opt()

    device_id = opt.gpuid
    device_count = len(str(device_id).split(","))
    #assert device_count == 1 or device_count == 2
    print("use GPU:", device_id, "GPUs_count", device_count, flush=True)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    # After CUDA_VISIBLE_DEVICES is set, the visible device is re-indexed from 0
    device_id = 0
    torch.cuda.set_device(0)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.batch_size, opt.workers, opt)

    # Construct the model
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)

    # Loss and Optimizer
    criterion = ContrastiveLoss(opt=opt, margin=opt.margin, max_violation=opt.max_violation)
    mse_criterion = nn.MSELoss(reduction="mean")  # "batchmean" is only valid for KLDivLoss
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

    # optionally resume from a checkpoint
    if not os.path.exists(opt.model_name):
        os.makedirs(opt.model_name)
    start_epoch = 0
    best_rsum = 0

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
    evalrank(model.module, val_loader, opt)

    print(opt, flush=True)
    
    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        message = "epoch: %d, model name: %s\n" % (epoch, opt.model_name)
        log_file = os.path.join(opt.logger_name, "performance.log")
        logging_func(log_file, message)
        print("model name: ", opt.model_name, flush=True)
        adjust_learning_rate(opt, optimizer, epoch)
        run_time = 0
        for i, (images, captions, lengths, masks, ids, _) in enumerate(train_loader):
            start_time = time.time()
            model.train()

            optimizer.zero_grad()

            if device_count != 1:
                images = images.repeat(device_count, 1, 1)

            score = model(images, captions, lengths, masks, ids)
            loss = criterion(score)

            loss.backward()
            if opt.grad_clip > 0:
                clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            run_time += time.time() - start_time
            # Log running loss and timing every 100 batches
            if i % 100 == 0:
                log = "epoch: %d; batch: %d/%d; loss: %.4f; time: %.4f" % (
                    epoch, i, len(train_loader), loss.item(), run_time / 100)
                print(log, flush=True)
                run_time = 0
            # Validate at every val_step
            if (i + 1) % opt.val_step == 0:
                evalrank(model.module, val_loader, opt)

        print("-------- performance at epoch: %d --------" % (epoch))
        # evaluate on validation set
        rsum = evalrank(model.module, val_loader, opt)
        #rsum = -100
        filename = 'model_' + str(epoch) + '.pth.tar'
        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
        }, is_best, filename=filename, prefix=opt.model_name + '/')
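
`adjust_learning_rate(opt, optimizer, epoch)` is called once per epoch above but is not defined in this snippet. Below is a minimal sketch, assuming the step decay used by the reference SCAN training script (divide the base rate by 10 every `opt.lr_update` epochs); `opt.lr_update` is an assumed option that does not appear in this example.

def adjust_learning_rate(opt, optimizer, epoch):
    # Assumed schedule: lr = base_lr * 0.1^(epoch // lr_update).
    lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
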
Example #3
        if (epoch % save_epoch == 0) or (epoch == training_epochs - 1):
            torch.save(scan.state_dict(),
                       '{}/scan_epoch_{}.pth'.format(exp, epoch))


data_manager = DataManager()
data_manager.prepare()

dae = DAE()
vae = VAE()
scan = SCAN()
if use_cuda:
    dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth'))
    vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth'))
    scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth'))
    dae, vae, scan = dae.cuda(), vae.cuda(), scan.cuda()
else:
    dae.load_state_dict(
        torch.load('save/dae/dae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    vae.load_state_dict(
        torch.load('save/vae/vae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    scan.load_state_dict(
        torch.load(exp + '/' + opt.load,
                   map_location=lambda storage, loc: storage))

if opt.train:
    scan_optimizer = optim.Adam(scan.parameters(), lr=1e-4, eps=1e-8)
    train_scan(dae, vae, scan, data_manager, scan_optimizer)
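
The indented save block at the top of Example #3 is the tail of the `train_scan` loop, whose body is not shown. A minimal sketch of a plausible enclosing loop follows; the `data_manager.next_batch()` and `scan.loss(...)` calls are hypothetical names, and `training_epochs`, `save_epoch`, and `exp` are assumed to be module-level settings, since the visible fragment reads them as globals. Only the checkpointing lines are taken from the source.

import torch

def train_scan(dae, vae, scan, data_manager, scan_optimizer):
    # Sketch only: one gradient step per epoch on a batch of paired
    # images and symbols drawn from the data manager.
    for epoch in range(training_epochs):
        images, symbols = data_manager.next_batch()  # hypothetical API
        scan_optimizer.zero_grad()
        loss = scan.loss(images, symbols)  # hypothetical API, e.g. reconstruction + KL terms
        loss.backward()
        scan_optimizer.step()

        # Checkpointing logic from the visible fragment.
        if (epoch % save_epoch == 0) or (epoch == training_epochs - 1):
            torch.save(scan.state_dict(),
                       '{}/scan_epoch_{}.pth'.format(exp, epoch))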