Example #1
def train(train_loader, validation_loader, net, embeddings):
    """
    Trains the network with a given training data loader and validation data
    loader.
    """
    optimizer = Adadelta(net.parameters())
    evaluate(validation_loader, net, 'validation', log=False)
    prev_best_acc = 0

    for epoch in range(10):
        print('Epoch:', epoch)
        net.train()

        avg_loss = 0
        avg_acc = 0

        for batch_idx, (vectors, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            logits = net(vectors)
            loss = F.cross_entropy(logits, targets)
            loss.backward()
            optimizer.step()

            corrects = float((torch.max(logits, 1)[1].view(
                targets.size()).data == targets.data).sum())
            accuracy = 100.0 * corrects / targets.size(0)
            avg_loss += float(loss)
            avg_acc += accuracy

        avg_loss /= batch_idx + 1
        avg_acc /= batch_idx + 1

        logger('training', 'loss', avg_loss)
        logger('training', 'accuracy', avg_acc)

        acc = evaluate(validation_loader, net, 'validation')

        if acc > prev_best_acc:
            torch.save(net.state_dict(), params_file)
            prev_best_acc = acc
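
The loop above assumes several names defined elsewhere in its project (evaluate, logger, batch_size, params_file). A minimal sketch of what the assumed evaluate() helper could look like, with the project's logger() swapped for print(), is:

import torch
import torch.nn.functional as F

def evaluate(loader, net, split, log=True):
    # Hypothetical stand-in for the project's evaluate(): mean loss and accuracy
    # over a data loader; returns the accuracy so the caller can checkpoint on it.
    net.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for vectors, targets in loader:
            logits = net(vectors)
            loss_sum += F.cross_entropy(logits, targets, reduction='sum').item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    accuracy = 100.0 * correct / max(total, 1)
    if log:
        print(split, 'loss:', loss_sum / max(total, 1), 'accuracy:', accuracy)
    return accuracy
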
Example #2
def train(training_data_file, valid_data_file, super_batch_size, tokenizer, mode, kw, p_key, model1, device, model2, model3, \
            batch_size, num_epoch, gradient_accumulation_steps, lr1, lr2, lambda_, valid_critic, early_stop):
    '''Train the three models.

    Trains the models on bundles streamed from the training file.

    Args:
        training_data_file (str) : training data json file, the raw json file used to load data
        valid_data_file (str) : validation data json file, used for early-stop evaluation
        super_batch_size (int) : how many samples are loaded into memory at once
        tokenizer : SentencePiece tokenizer used to obtain the token ids
        mode (str) : format of the passage; it can be a list (processed) or a long string (unprocessed)
        kw (str) : the key that maps to the passage in each data dictionary. Defaults to 'abstract'
        p_key (str) : the key used to look up a specific passage. Defaults to 'title'
        model1 (nn.DataParallel) : local dependency encoder
        device (torch.device) : the device which models and data are on
        model2 (nn.Module) : global coherence encoder
        model3 (nn.Module) : attention decoder
        batch_size (int) : Defaults to 4.
        num_epoch (int) : Defaults to 1.
        gradient_accumulation_steps (int) : Defaults to 1.
        lr1 (float) : starting learning rate for the BERT optimizer. Defaults to 1e-4.
        lr2 (float) : learning rate for the Adadelta optimizer.
        lambda_ (float) : balance factor for parameter weight decay. Defaults to 0.01.
        valid_critic (str) : which metric to use for early-stop evaluation
            ('rouge-w', 'acc', 'ken-tau' or 'pmr')
        early_stop (int) : early-stop patience in epochs. Defaults to 5.

    '''

    # Prepare optimizer for Sys1
    param_optimizer_bert = list(model1.named_parameters())
    param_optimizer_others = list(model2.named_parameters()) + list(
        model3.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # We tend to keep the embedding fixed; for now we have not located the embedding layer
    optimizer_grouped_parameters_bert = [{
        'params': [
            p for n, p in param_optimizer_bert
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        lambda_
    }, {
        'params': [
            p for n, p in param_optimizer_bert
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer_grouped_parameters_others = [{
        'params': [
            p for n, p in param_optimizer_others
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        lambda_
    }, {
        'params': [
            p for n, p in param_optimizer_others
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    # We should add a module here to count the number of parameters
    critic = nn.NLLLoss(reduction='none')

    line_num = int(os.popen("wc -l " + training_data_file).read().split()[0])
    global_step = 0  # global step
    opt1 = BertAdam(optimizer_grouped_parameters_bert,
                    lr=lr1,
                    warmup=0.1,
                    t_total=line_num / batch_size * num_epoch)  # optimizer 1
    # opt = Adam(optimizer_grouped_parameter, lr=lr)
    opt2 = Adadelta(optimizer_grouped_parameters_others, lr=lr2, rho=0.95)
    model1.to(device)
    model1.train()
    model2.to(device)
    model2.train()
    model3.to(device)
    model3.train()
    warmed = True
    for epoch in trange(num_epoch, desc='Epoch'):

        smooth_mean = WindowMean()
        opt1.zero_grad()
        opt2.zero_grad()

        for superbatch, line_num in load_superbatch(training_data_file,
                                                    super_batch_size):
            bundles = []

            for data in superbatch:
                try:
                    bundles.append(
                        convert_passage_to_samples_bundle(
                            tokenizer, data, mode, kw, p_key))

                except:
                    print_exc()

            num_batch, dataloader = homebrew_data_loader(bundles,
                                                         batch_size=batch_size)

            tqdm_obj = tqdm(dataloader, total=num_batch)
            num_steps = line_num
            for step, batch in enumerate(tqdm_obj):
                try:
                    #batch[0] = batch[0].to(device)
                    #batch[1] = batch[1].to(device)
                    #batch[2] = batch[2].to(device)
                    batch = tuple(t for t in batch)
                    log_prob_loss, pointers_output, ground_truth = calculate_loss(
                        batch, model1, model2, model3, device, critic)
                    # here we need to add code to cal rouge-w and acc
                    rouge_ws = []
                    accs = []
                    ken_taus = []
                    pmrs = []
                    for pred, true in zip(pointers_output, ground_truth):
                        rouge_ws.append(rouge_w(pred, true))
                        accs.append(acc(pred, true))
                        ken_taus.append(kendall_tau(pred, true))
                        pmrs.append(pmr(pred, true))

                    log_prob_loss.backward()

                    # ******** The block below handles gradient accumulation, warm-up and periodic checkpointing; early stopping is applied after validation ************

                    if (step + 1) % gradient_accumulation_steps == 0:
                        # modify learning rate with special warm up BERT uses. From BERT pytorch examples
                        lr_this_step = lr1 * warmup_linear(
                            global_step / num_steps, warmup=0.1)
                        for param_group in opt1.param_groups:
                            param_group['lr'] = lr_this_step
                        global_step += 1

                        opt2.step()
                        opt2.zero_grad()
                        smooth_mean_loss = smooth_mean.update(
                            log_prob_loss.item())
                        tqdm_obj.set_description(
                            '{}: {:.4f}, {}: {:.4f}, smooth_mean_loss: {:.4f}'.
                            format('accuracy', np.mean(accs), 'rouge-w',
                                   np.mean(rouge_ws), smooth_mean_loss))
                        # During warming period, model1 is frozen and model2 is trained to normal weights
                        if smooth_mean_loss < 1.0 and step > 100:  # ugly manual hyperparam
                            warmed = True
                        if warmed:
                            opt1.step()
                        opt1.zero_grad()
                        if step % 1000 == 0:
                            output_model_file = './models/bert-base-cased.bin.tmp'
                            saved_dict = {
                                'params1': model1.module.state_dict()
                            }
                            saved_dict['params2'] = model2.state_dict()
                            saved_dict['params3'] = model3.state_dict()
                            torch.save(saved_dict, output_model_file)

                except Exception as err:
                    traceback.print_exc()
                    exit()
                    # if mode == 'list':
                    #     print(batch._id)

        if epoch < 5:
            best_score = 0
            continue

        with torch.no_grad():
            print('valid..............')

            valid_critic_dict = {
                'rouge-w': rouge_w,
                'acc': acc,
                'ken-tau': kendall_tau,
                'pmr': pmr
            }

            for superbatch, _ in load_superbatch(valid_data_file,
                                                 super_batch_size):
                bundles = []

                for data in superbatch:
                    try:
                        bundles.append(
                            convert_passage_to_samples_bundle(
                                tokenizer, data, mode, kw, p_key))
                    except:
                        print_exc()

                num_batch, valid_dataloader = homebrew_data_loader(
                    bundles, batch_size=1)

                valid_value = []
                for step, batch in enumerate(valid_dataloader):
                    try:
                        batch = tuple(t for idx, t in enumerate(batch))
                        pointers_output, ground_truth \
                            = dev_test(batch, model1, model2, model3, device)
                        valid_value.append(valid_critic_dict[valid_critic](
                            pointers_output, ground_truth))

                    except Exception as err:
                        traceback.print_exc()
                        # if mode == 'list':
                        #     print(batch._id)

                score = np.mean(valid_value)
            print('epc:{}, {} : {:.2f} best : {:.2f}\n'.format(
                epoch, valid_critic, score, best_score))

            if score > best_score:
                best_score = score
                best_iter = epoch

                print('Saving model to {}'.format(
                    output_model_file))  # save model structure
                saved_dict = {
                    'params1': model1.module.state_dict()
                }  # save parameters
                saved_dict['params2'] = model2.state_dict()  # save parameters
                saved_dict['params3'] = model3.state_dict()
                torch.save(saved_dict, output_model_file)

                # print('save best model at epc={}'.format(epc))
                # checkpoint = {'model': model.state_dict(),
                #             'args': args,
                #             'loss': best_score}
                # torch.save(checkpoint, '{}/{}.best.pt'.format(args.model_path, args.model))

            if early_stop and (epoch - best_iter) >= early_stop:
                print('early stop at epc {}'.format(epoch))
                break
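
Example #2 calls several helpers it does not define. Two of the smaller ones can be sketched: warmup_linear below follows the schedule shipped with early pytorch-pretrained-bert releases, and WindowMean is assumed to be a simple sliding-window average used only for smoother logging; both are sketches, not the repository's actual code.

from collections import deque

def warmup_linear(x, warmup=0.002):
    # Linear warm-up followed by linear decay, as in early pytorch-pretrained-bert.
    if x < warmup:
        return x / warmup
    return 1.0 - x

class WindowMean:
    # Assumed: running mean over the most recent `window` loss values.
    def __init__(self, window=50):
        self.values = deque(maxlen=window)

    def update(self, value):
        self.values.append(value)
        return sum(self.values) / len(self.values)
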
Example #3
    def train(self):
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        device_ids = [0]
        self.classifier = classifier()
        get_paprams(self.classifier)
        get_paprams(self.classifier.base)
        # data_set_eval = my_dataset(eval=True)
        # data_set = my_dataset_10s()
        # data_set_test = my_dataset_10s()
        data_set = my_dataset_10s_smote()
        data_set_test = my_dataset_10s_smote(test=True, all_data=data_set.all_data, all_label=data_set.all_label,
                                             index_=data_set.index)
        # data_set_eval = my_dataset_10s(eval=True)
        # data_set_combine = my_dataset(combine=True)
        batch = 300
        totoal_epoch = 2000
        print('batch:{}'.format(batch))
        # self.evaluation = evaluation
        data_loader = DataLoader(data_set, batch, shuffle=True, collate_fn=detection_collate)
        data_loader_test = DataLoader(data_set_test, batch, False, collate_fn=detection_collate)
        # data_loader_eval = DataLoader(data_set_eval, batch, False, collate_fn=detection_collate)
        self.classifier = self.classifier.cuda()
        self.classifier = DataParallel(self.classifier, device_ids=device_ids)
        optim = Adadelta(self.classifier.parameters(), 0.1, 0.9, weight_decay=1e-5)

        self.cretion = smooth_focal_weight()

        self.classifier.apply(weights_init)
        start_time = time.time()
        count = 0
        epoch = -1
        while 1:
            epoch += 1
            runing_losss = [0] * 5
            for data in data_loader:
                loss = [0] * 5
                y = data[1].cuda()
                x = data[0].cuda()
                optim.zero_grad()

                weight = torch.Tensor([0.5, 2, 0.5, 2]).cuda()

                inputs, targets_a, targets_b, lam = mixup_data(x, y)
                predict = self.classifier(inputs)  # feed the mixed batch (the original passed the un-mixed x, leaving mixup unused)
                ############################

                loss_func = mixup_criterion(targets_a, targets_b, lam, weight)
                loss5 = loss_func(self.cretion, predict[0])
                loss4 = loss_func(self.cretion, predict[1]) * 0.4
                loss3 = loss_func(self.cretion, predict[2]) * 0.3
                loss2 = loss_func(self.cretion, predict[3]) * 0.2
                loss1 = loss_func(self.cretion, predict[4]) * 0.1

                tmp = loss5 + loss4 + loss3 + loss2 + loss1

                # tmp = sum(loss)
                tmp.backward()
                optim.step()
                for i, l in enumerate((loss1, loss2, loss3, loss4, loss5)):
                    runing_losss[i] += l.item()

                count += 1
                # torch.cuda.empty_cache()
            end_time = time.time()
            print(
                "epoch:{a}: loss:{b} spend_time:{c} time:{d}".format(a=epoch, b=sum(runing_losss),
                                                                     c=int(end_time - start_time),
                                                                     d=time.asctime()))
            start_time = end_time

            # vis.line(np.asarray([optim.param_groups[0]['lr']]), np.asarray([epoch]), win="lr", update='append',
            #          opts=dict(title='lr'))
            # if (epoch > 20):
            #     runing_losss = np.asarray(runing_losss).reshape(1, 5)

            # vis.line(runing_losss,
            #          np.asarray([epoch] * 5).reshape(1, 5), win="loss-epoch", update='append',
            #          opts=dict(title='loss', legend=['loss1', 'loss2', 'loss3', 'loss4', 'loss5', 'loss6']))
            save(self.classifier.module.base.state_dict(),
                 str(epoch) + 'base_c2.p')
            save(self.classifier.module.state_dict(),
                 str(epoch) + 'base_all_c2.p')
            # print('eval:{}'.format(time.asctime(time.localtime(time.time()))))
            self.classifier.eval()
            # self.evaluation(self.classifier, data_loader_eval)
            # print('test:{}'.format(time.asctime(time.localtime(time.time()))))
            # self.evaluation(self.classifier, data_loader_eval, epoch)
            self.evaluation(self.classifier, data_loader_test, epoch)
            # self.evaluation(self.classifier, data_loader, epoch)

            # print('combine:{}'.format(time.asctime(time.localtime(time.time()))))
            # evaluation(self.classifier, data_loader_combine)
            self.classifier.train()
            if epoch % 10 == 0:
                adjust_learning_rate(optim, 0.9, epoch, totoal_epoch, 0.1)
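
mixup_data and mixup_criterion are not shown in example #3. The call pattern matches the closure-style helpers from the public mixup reference code, so a sketch under that assumption looks like the following (how the weight tensor reaches the custom smooth_focal_weight criterion is left out, since that part is unknown):

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # Convex combination of the batch with a shuffled copy of itself.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(y_a, y_b, lam, weight=None):
    # Returns a closure over the mixed targets; `weight` is accepted only to
    # match the call site above.
    return lambda criterion, pred: (lam * criterion(pred, y_a)
                                    + (1 - lam) * criterion(pred, y_b))
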
Example #4
def train(pretrain=PRETRAIN):
    logging.debug('pretrain:{}'.format(pretrain))
    if DEVICE == 'cuda':
        if torch.cuda.is_available() == False:
            logging.error("can't find a GPU device")
            pdb.set_trace()
    #model=DenseLSTM(NUM_CLASS)
    #model=VGGLSTM(NUM_CLASS)
    #model=DenseCNN(NUM_CLASS)
    #model=VGGFC(NUM_CLASS)
    model = ResNetLSTM(NUM_CLASS)
    if os.path.exists(MODEL_PATH) == False:
        os.makedirs(MODEL_PATH)
    if os.path.exists(PATH + DICTIONARY_NAME) == False:
        logging.error("can't find the dictionary")
        pdb.set_trace()
    with open(PATH + DICTIONARY_NAME, 'r') as f:
        dictionary = json.load(f)
    if pretrain == True:
        model.load_state_dict(
            torch.load(MODEL_PATH + MODEL_NAME, map_location=DEVICE))
    model.to(DEVICE).train()
    model.register_backward_hook(backward_hook)  #transforms.Resize((32,400))
    dataset = ICDARRecTs_2DataSet(IMAGE_PATH,
                                  dictionary,
                                  BATCH_SIZE,
                                  img_transform=transforms.Compose([
                                      transforms.ColorJitter(brightness=0.5,
                                                             contrast=0.5,
                                                             saturation=0.5,
                                                             hue=0.3),
                                      transforms.ToTensor(),
                                      transforms.Normalize(
                                          (0.485, 0.456, 0.406),
                                          (0.229, 0.224, 0.225))
                                  ]))
    dataloader = DataLoader(dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)  #collate_fn=dataset.collate
    #optimizer=Adam(model.parameters(),lr=LR,betas=(0.9,0.999),weight_decay=0)
    optimizer = Adadelta(model.parameters(), lr=0.01, rho=0.9, weight_decay=0)
    criterion = CTCLoss(blank=0)
    length = len(dataloader)
    max_accuracy = 0
    if os.path.exists('max_accuracy.txt') == True:
        with open('max_accuracy.txt', 'r') as f:
            max_accuracy = float(f.read())
    for epoch in range(EPOCH):
        epoch_time = datetime.now()
        epoch_correct = 0
        epoch_loss = 0
        min_loss = 100
        for step, data in enumerate(dataloader):
            step_time = datetime.now()
            imgs, names, label_size, img_name = data
            #print(names,label_size)
            logging.debug("imgs' size:{}".format(imgs.size()))
            imgs = Variable(imgs, requires_grad=True).to(DEVICE)
            label, batch_label = dataset.transform_label(batch_name=names)
            label = Variable(label).to(DEVICE)
            label_size = Variable(label_size).to(DEVICE)
            preds = model(imgs)
            logging.debug("preds size:{}".format(preds.size()))
            preds_size = Variable(
                torch.LongTensor([preds.size(0)] * BATCH_SIZE)).to(DEVICE)
            loss = criterion(preds, label, preds_size, label_size)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()
            if min_loss > loss.item():
                min_loss = loss.item()
                torch.save(model.state_dict(), MODEL_PATH + MODEL_NAME)
            num_same = if_same(preds.cpu().data, batch_label)
            epoch_correct += num_same
            logging.debug(
                "Epoch:{}|length:{}|step:{}|num_same:{}|loss:{:.4f}|min loss:{:.4f}"
                .format(epoch, length, step, num_same, loss.item(), min_loss))
            logging.debug("the time of one step:{}".format(datetime.now() -
                                                           step_time))
            if step % 100 == 0:
                clear_output(wait=True)
        accuracy = epoch_correct / (length * BATCH_SIZE)
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            with open('max_accuracy.txt', 'w') as f:
                f.write(str(max_accuracy))
            torch.save(model.state_dict(), MODEL_PATH + MODEL_NAME)
            torch.save(model.state_dict(),
                       MODEL_PATH + 'optimal' + str(max_accuracy) + MODEL_NAME)
        mean_loss = epoch_loss / length
        logging.info(
            'Epoch:{}|accuracy:{}|mean loss:{}|the time of one epoch:{}|max accuracy:{}'
            .format(epoch, accuracy, mean_loss,
                    datetime.now() - epoch_time, max_accuracy))
        with open('accuracy.txt', 'a+') as f:
            f.write(
                'Epoch:{}|accuracy:{}|mean loss:{}|the time of one epoch:{}|max accuracy:{}\n'
                .format(epoch, accuracy, mean_loss,
                        datetime.now() - epoch_time, max_accuracy))
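
if_same() in example #4 is project-specific. A plausible sketch, assuming it greedily decodes the CTC output (argmax over classes, collapse repeats, drop the blank index 0) and counts exact matches against the encoded labels, is:

import torch

def if_same(preds, batch_label):
    # preds: (T, N, num_classes) CTC output; batch_label: iterable of label id sequences.
    best_path = preds.argmax(dim=2).transpose(0, 1)  # (N, T)
    correct = 0
    for seq, label in zip(best_path, batch_label):
        decoded, prev = [], -1
        for t in seq.tolist():
            if t != prev and t != 0:
                decoded.append(t)
            prev = t
        if decoded == list(label):
            correct += 1
    return correct
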
Example #5
        latent_samples = torch.randn(bs, z_dim)
        d_gen_input = Variable(latent_samples)
        if GPU_NUMS > 0:
            d_gen_input = d_gen_input.cuda()
        d_fake_data = generator(
            d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = discriminator(d_fake_data)
        labels = Variable(torch.zeros(bs))
        if GPU_NUMS > 0:
            labels = labels.cuda()
        d_fake_loss = criterion(d_fake_decision, labels)  # zeros = fake

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()

        d_optimizer.step(
        )  # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(1):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        generator.zero_grad()

        latent_samples = torch.randn(bs, z_dim)
        g_gen_input = Variable(latent_samples)
        if GPU_NUMS > 0:
            g_gen_input = g_gen_input.cuda()
        g_fake_data = generator(g_gen_input)
        g_fake_decision = discriminator(g_fake_data)
        labels = Variable(torch.ones(bs))
        if GPU_NUMS > 0:
            labels = labels.cuda()
        g_loss = criterion(
            g_fake_decision, labels)  # ones = genuine; train G to make D label its output real
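
The listing is truncated at this point. In the toy GAN loop this snippet appears to come from, the generator update would presumably finish with a backward pass and an optimizer step (g_optimizer is an assumed name, mirroring the d_optimizer used above):

        g_loss.backward()
        g_optimizer.step()  # only optimizes G's parameters
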
Example #6
def train_predict(batch_size=100, epochs=10, topk=30, L2=1e-8):
    patients = getTrainData(4000000)  # patients × visits × medical_code

    patients_num = len(patients)
    train_patient_num = int(patients_num * 0.8)
    patients_train = patients[0:train_patient_num]
    test_patient_num = patients_num - train_patient_num
    patients_test = patients[train_patient_num:]

    train_batch_num = int(np.ceil(float(train_patient_num) / batch_size))
    test_batch_num = int(np.ceil(float(test_patient_num) / batch_size))

    model = Dipole(input_dim=3393,
                   day_dim=200,
                   rnn_hiddendim=300,
                   output_dim=283)

    params = list(model.parameters())
    k = 0
    for i in params:
        l = 1
        print("Layer shape: " + str(list(i.size())))
        for j in i.size():
            l *= j
        print("Layer parameter count: " + str(l))
        k = k + l
    print("Total number of parameters: " + str(k))

    optimizer = Adadelta(model.parameters(), lr=1, weight_decay=L2)
    loss_mce = nn.BCELoss(reduction='sum')
    model = model.cuda(device=1)

    for epoch in range(epochs):
        starttime = time.time()
        # training
        model.train()
        all_loss = 0.0
        for batch_index in range(train_batch_num):
            patients_batch = patients_train[batch_index *
                                            batch_size:(batch_index + 1) *
                                            batch_size]
            patients_batch_reshape, patients_lengths = model.padTrainMatrix(
                patients_batch)  # maxlen × n_samples × inputDimSize
            batch_x = patients_batch_reshape[0:-1]  # use the first n-1 visits as x to predict the next n-1 steps
            # batch_y = patients_batch_reshape[1:]
            batch_y = patients_batch_reshape[1:, :, :283]  # keep only the medication codes as y
            optimizer.zero_grad()
            # h0 = model.initHidden(batch_x.shape[1])
            batch_x = torch.tensor(batch_x, device=torch.device('cuda:1'))
            batch_y = torch.tensor(batch_y, device=torch.device('cuda:1'))
            y_hat = model(batch_x)
            mask = out_mask2(y_hat,
                             patients_lengths)  # build a mask that zeroes outputs at padded positions
            # use the mask to zero network outputs beyond each sequence's true length
            y_hat = y_hat.mul(mask)
            batch_y = batch_y.mul(mask)
            # (seq_len, batch_size, out_dim)->(seq_len*batch_size*out_dim, 1)->(seq_len*batch_size*out_dim, )
            y_hat = y_hat.view(-1, 1).squeeze()
            batch_y = batch_y.view(-1, 1).squeeze()

            loss = loss_mce(y_hat, batch_y)
            loss.backward()
            optimizer.step()
            all_loss += loss.item()
        print("Train:Epoch-" + str(epoch) + ":" + str(all_loss) +
              " Train Time:" + str(time.time() - starttime))

        # testing
        model.eval()
        NDCG = 0.0
        RECALL = 0.0
        DAYNUM = 0.0
        all_loss = 0.0
        gbert_pred = []
        gbert_true = []
        gbert_len = []

        for batch_index in range(test_batch_num):
            patients_batch = patients_test[batch_index *
                                           batch_size:(batch_index + 1) *
                                           batch_size]
            patients_batch_reshape, patients_lengths = model.padTrainMatrix(
                patients_batch)
            batch_x = patients_batch_reshape[0:-1]
            batch_y = patients_batch_reshape[1:, :, :283]
            batch_x = torch.tensor(batch_x, device=torch.device('cuda:1'))
            batch_y = torch.tensor(batch_y, device=torch.device('cuda:1'))
            y_hat = model(batch_x)
            mask = out_mask2(y_hat, patients_lengths)
            loss = loss_mce(y_hat.mul(mask), batch_y.mul(mask))

            all_loss += loss.item()
            y_hat = y_hat.detach().cpu().numpy()
            ndcg, recall, daynum = validation(y_hat, patients_batch,
                                              patients_lengths, topk)
            NDCG += ndcg
            RECALL += recall
            DAYNUM += daynum
            gbert_pred.append(y_hat)
            gbert_true.append(batch_y.cpu())
            gbert_len.append(patients_lengths)

        avg_NDCG = NDCG / DAYNUM
        avg_RECALL = RECALL / DAYNUM
        y_pred_all, y_true_all = batch_squeeze(gbert_pred, gbert_true,
                                               gbert_len)
        acc_container = metric_report(y_pred_all, y_true_all, 0.2)
        print("Test:Epoch-" + str(epoch) + " Loss:" + str(all_loss) +
              " Test Time:" + str(time.time() - starttime))
        print("Test:Epoch-" + str(epoch) + " NDCG:" + str(avg_NDCG) +
              " RECALL:" + str(avg_RECALL))
        print("Test:Epoch-" + str(epoch) + " Jaccard:" +
              str(acc_container['jaccard']) + " f1:" +
              str(acc_container['f1']) + " prauc:" +
              str(acc_container['prauc']) + " roauc:" +
              str(acc_container['auc']))

        print("")
Example #7
def main(args):

    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)
    if args.init_model_cn != None:
        args.init_model_cn = os.path.expanduser(args.init_model_cn)
    if args.init_model_cd != None:
        args.init_model_cd = os.path.expanduser(args.init_model_cd)
    if torch.cuda.is_available() == False:
        raise Exception('At least one gpu must be available.')
    else:
        gpu = torch.device('cuda:0')

    # create result directory (if necessary)
    if os.path.exists(args.result_dir) == False:
        os.makedirs(args.result_dir)
    for s in ['phase_1', 'phase_2', 'phase_3']:
        if os.path.exists(os.path.join(args.result_dir, s)) == False:
            os.makedirs(os.path.join(args.result_dir, s))

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mean pixel value of training dataset
    mpv = np.zeros(shape=(3, ))
    if args.mpv == None:
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value for training dataset...')
        for imgpath in train_dset.imgpaths:
            img = Image.open(imgpath)
            x = np.array(img, dtype=np.float32) / 255.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # save training config
    mpv_json = []
    for i in range(3):
        mpv_json.append(float(mpv[i]))  # convert to json serializable type
    args_dict = vars(args)
    args_dict['mpv'] = mpv_json
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensor
    mpv = torch.tensor(mpv.astype(np.float32).reshape(1, 3, 1, 1)).to(gpu)
    alpha = torch.tensor(args.alpha).to(gpu)

    my_writer = SummaryWriter(log_dir='log')

    # ================================================
    # Training Phase 1
    # ================================================
    model_cn = CompletionNetwork()
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    if args.init_model_cn != None:
        new = OrderedDict()
        for k, v in torch.load(args.init_model_cn, map_location='cpu').items():
            new['module.' + k] = v

        # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu'))
        model_cn.load_state_dict(new)
        print('Phase 1 model loaded successfully!')
        # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu'))
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())
    model_cn = model_cn.to(gpu)

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    pbar.n = 90000
    while pbar.n < args.steps_1:

        for i, x in enumerate(train_loader):
            # forward
            x = x.to(gpu)
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=gen_hole_area(
                    (args.ld_input_size, args.ld_input_size),
                    (x.shape[3], x.shape[2])),
                max_holes=args.max_holes,
            ).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input = torch.cat((x_mask, mask), dim=1)
            output = model_cn(input)
            loss_mse = completion_network_loss(x, output, mask)

            loss_contextual = contextual_loss(x, output, mask)
            loss = loss_mse + 0.004 * loss_contextual
            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                # clear grads
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()

                my_writer.add_scalar('mse_loss', loss_mse,
                                     pbar.n * len(train_loader) + i)

                # test
                if pbar.n % args.snaperiod_1 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_1',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_1',
                            'model_cn_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                # terminate
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
        arc=args.arc,
    )
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    if args.init_model_cd != None:
        # model_cd.load_state_dict(torch.load(args.init_model_cd, map_location='cpu'))
        new = OrderedDict()
        for k, v in torch.load(args.init_model_cd, map_location='cpu').items():
            new['module.' + k] = v

        # model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu'))
        model_cd.load_state_dict(new)
        print('Phase 2 model loaded successfully!')
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    model_cd = model_cd.to(gpu)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    pbar.n = 120000
    while pbar.n < args.steps_2:
        for x in train_loader:

            # fake forward
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()  # generated image fed to the global discriminator
            input_ld_fake = crop(input_gd_fake, hole_area_fake)  # generated patch fed to the local discriminator
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)  # real patch fed to the local discriminator
            output_real = model_cd(
                (input_ld_real, input_gd_real))  # discriminator output on the real pair
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cd.step()
                # clear grads
                opt_cd.zero_grad()
                # update progbar
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test
                if pbar.n % args.snaperiod_2 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x, output, mask)  # Poisson blending
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_2',
                                               'step%d.png' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_2',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    pbar.n = 120000
    while pbar.n < args.steps_3:
        for i, x in enumerate(train_loader):

            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)

            # fake forward
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_cd_real = bceloss(output_real, real)

            # reduce
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()

            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                # optimize
                opt_cd.step()
                # clear grads
                opt_cd.zero_grad()

            # forward model_cn
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, (input_gd_fake)))
            loss_cn_2 = bceloss(output_fake, real)

            loss_cn_3 = contextual_loss(x, output_cn, mask)
            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2 + 4e-3 * loss_cn_3) / 2.

            # backward model_cn
            loss_cn.backward()

            my_writer.add_scalar('mse_loss', loss_cn_1,
                                 (90000 + pbar.n) * len(train_loader) + i)

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                # clear grads
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()
                # test
                if pbar.n % args.snaperiod_3 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_3',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cn_step%d' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                            torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
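
completion_network_loss() in example #7 comes from the GLCIC-style codebase the script is built on; it is usually just the MSE restricted to the hole region, so a sketch under that assumption is:

import torch.nn.functional as F

def completion_network_loss(x, output, mask):
    # Mean-squared error between the completed output and the ground truth,
    # evaluated only where the mask marks a hole.
    return F.mse_loss(output * mask, x * mask)
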
Example #8
def main():
    args = read_args(default_config="confs/kim_cnn_sst2.json")
    set_seed(args.seed)
    os.makedirs(args.workspace, exist_ok=True)
    torch.backends.cudnn.deterministic = True

    dataset_cls = find_dataset(args.dataset_name)
    training_iter, dev_iter, test_iter = dataset_cls.iters(args.dataset_path, args.vectors_file, args.vectors_dir,
        batch_size=args.batch_size, device=args.device, train=args.train_file, dev=args.dev_file, test=args.test_file)

    args.dataset = training_iter.dataset
    args.words_num = len(training_iter.dataset.TEXT_FIELD.vocab)
    model = mod.SiameseRNNModel(args).to(args.device)
    ckpt_attrs = mod.load_checkpoint(model, args.workspace,
        best=args.load_best_checkpoint) if args.load_last_checkpoint or args.load_best_checkpoint else {}
    offset = ckpt_attrs.get("epoch_idx", -1) + 1
    args.epochs -= offset

    training_pbar = tqdm(total=len(training_iter), position=2)
    training_pbar.set_description("Training")
    dev_pbar = tqdm(total=args.epochs, position=1)
    dev_pbar.set_description("Dev")

    criterion = nn.CrossEntropyLoss()
    kd_criterion = nn.KLDivLoss(reduction="batchmean")
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = Adadelta(params, lr=args.lr, rho=0.95)
    increment_fn = mod.make_checkpoint_incrementer(model, args.workspace, save_last=True,
        best_loss=ckpt_attrs.get("best_dev_loss", 10000))
    non_embedding_params = model.non_embedding_params()

    if args.use_data_parallel:
        model = nn.DataParallel(model)
    if args.eval_test_only:
        test_acc, _ = evaluate(model, test_iter, criterion, export_eval_labels=args.export_eval_labels)
        print(test_acc)
        return
    if args.epochs == 0:
        print("No epochs left from loaded model.", file=sys.stderr)
        return
    for epoch_idx in tqdm(range(args.epochs), position=0):
        training_iter.init_epoch()
        model.train()
        training_pbar.n = 0
        training_pbar.refresh()
        for batch in training_iter:
            training_pbar.update(1)
            optimizer.zero_grad()
            logits = model(batch.sentence1, batch.sentence2)
            kd_logits = torch.stack((batch.logits_0, batch.logits_1, batch.logits_2), 1)
            kd = args.distill_lambda * kd_criterion(F.log_softmax(logits / args.distill_temperature, 1),
                F.softmax(kd_logits / args.distill_temperature, 1))
            loss = args.ce_lambda * criterion(logits, batch.gold_label) + kd
            loss.backward()
            clip_grad_norm_(non_embedding_params, args.clip_grad)
            optimizer.step()
            acc = ((logits.max(1)[1] == batch.gold_label).float().sum() / batch.gold_label.size(0)).item()
            training_pbar.set_postfix(accuracy=f"{acc:.2}")

        model.eval()
        dev_acc, dev_loss = evaluate(model, dev_iter, criterion)
        dev_pbar.update(1)
        dev_pbar.set_postfix(accuracy=f"{dev_acc:.4}")
        is_best_dev = increment_fn(dev_loss, dev_acc=dev_acc, epoch_idx=epoch_idx + offset)

        if is_best_dev:
            dev_pbar.set_postfix(accuracy=f"{dev_acc:.4} (best loss)")
            test_acc, _ = evaluate(model, test_iter, criterion, export_eval_labels=args.export_eval_labels)
    training_pbar.close()
    dev_pbar.close()
    print(f"Test accuracy of the best model: {test_acc:.4f}", file=sys.stderr)
    print(test_acc)
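
evaluate() in example #8 is imported from the surrounding project. A hypothetical version consistent with how it is called here (returning (accuracy, loss) and optionally printing predicted labels) might be:

import torch

def evaluate(model, data_iter, criterion, export_eval_labels=False):
    # Batches are assumed to expose .sentence1, .sentence2 and .gold_label,
    # matching the training loop above.
    model.eval()
    correct, total, loss_sum, n_batches = 0, 0, 0.0, 0
    with torch.no_grad():
        for batch in data_iter:
            logits = model(batch.sentence1, batch.sentence2)
            loss_sum += criterion(logits, batch.gold_label).item()
            preds = logits.max(1)[1]
            if export_eval_labels:
                print('\n'.join(str(p.item()) for p in preds))
            correct += (preds == batch.gold_label).sum().item()
            total += batch.gold_label.size(0)
            n_batches += 1
    return correct / max(total, 1), loss_sum / max(n_batches, 1)
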
Example #9
    def train_stage_1(self):
        data_set = loader(None, {"seed": 10, "mode": "training"})
        data_set_test = loader(None, {
            "seed": 10,
            "mode": "test"
        }, data_set.index)
        data_set_eval = loader(None, {
            "seed": 10,
            "mode": "eval"
        }, data_set.index)
        data_loader = DataLoader(data_set,
                                 self.batch,
                                 True,
                                 collate_fn=call_back.detection_collate_RPN,
                                 num_workers=0)
        data_loader_test = DataLoader(
            data_set_test,
            self.batch,
            False,
            collate_fn=call_back.detection_collate_RPN,
            num_workers=0,
        )

        # optim = Adadelta(self.RPN.parameters(), lr=lr1, weight_decay=1e-5)
        optim = Adadelta(self.RPN.parameters(), lr=self.lr1, weight_decay=1e-5)

        tool = rpn_tool_d()
        start_time = time.time()
        # print(optim.state_dict())
        for epoch in range(3000):
            runing_losss = 0.0
            cls_loss = 0
            coor_loss = 0

            for data in data_loader:
                y = data[1]
                x = data[0].cuda()

                optim.zero_grad()
                with torch.no_grad():
                    x1, x2, x3, x4 = self.features(x)
                predict_confidence, box_predict = self.RPN(x1, x2, x3, x4)
                cross_entropy, loss_box = tool.get_proposal(
                    predict_confidence, box_predict, y)
                loss_total = cross_entropy + loss_box
                loss_total.backward()
                optim.step()
                runing_losss += loss_total.item()
                cls_loss += cross_entropy.item()
                coor_loss += loss_box.item()
            end_time = time.time()
            # self.vis.line(np.asarray([cls_loss, coor_loss]).reshape(1, 2),
            #               np.asarray([epoch] * 2).reshape(1, 2), win="loss-epoch", update='append',
            #               opts=dict(title='loss', legend=['cls_loss', 'cor_loss']))
            print(
                "epoch:{a}: loss:{b:.4f} spend_time:{c} cls:{d:.4f} cor:{e:.4f} date:{ff}"
                .format(a=epoch,
                        b=runing_losss,
                        c=int(end_time - start_time),
                        d=cls_loss,
                        e=coor_loss,
                        ff=time.asctime()))
            start_time = end_time

            # if self.add_eval:
            #     p = self.RPN_eval(self,data_loader_eval, epoch, eval=True, seed=self.seed)
            self.RPN_eval(data_loader_test, {"epoch": epoch})

            save(self.RPN.module.state_dict(),
                 os.path.join(os.getcwd(),
                              str(epoch) + 'rpn_a1.p'))
            save(self.RPN.module.state_dict(),
                 os.path.join(os.getcwd(),
                              str(epoch) + 'base_a1.p'))

            if epoch % 10 == 0 and epoch > 0:
                adjust_learning_rate(optim, 0.9, epoch, 50, 0.3)
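
adjust_learning_rate() (also used in example #3) is project-specific. Judging only from the call sites, it presumably decays every parameter group's learning rate by a factor while keeping it above a floor; a rough, hypothetical version:

def adjust_learning_rate(optimizer, gamma, epoch, total_epoch, lr_floor):
    # epoch and total_epoch are accepted to match the call sites, even though
    # this simple sketch only applies a geometric decay with a lower bound.
    for group in optimizer.param_groups:
        group['lr'] = max(group['lr'] * gamma, lr_floor)
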
Example #10
class MusCapsTrainer(BaseTrainer):
    def __init__(self, config, logger):
        super(BaseTrainer, self).__init__()
        self.config = config
        self.logger = logger
        self.device = torch.device(self.config.training.device)
        self.patience = self.config.training.patience
        self.lr = self.config.training.lr

        self.load_dataset()
        self.build_model()
        self.build_loss()
        self.build_optimizer()

    def load_dataset(self):
        self.logger.write("Loading dataset")
        dataset_name = self.config.dataset_config.dataset_name
        if dataset_name == "audiocaption":
            train_dataset = AudioCaptionDataset(self.config.dataset_config)
            val_dataset = AudioCaptionDataset(self.config.dataset_config,
                                              "val")
        else:
            raise ValueError(
                "{} dataset is not supported.".format(dataset_name))
        self.vocab = train_dataset.vocab
        self.logger.save_vocab(self.vocab.token_freq)
        OmegaConf.update(self.config, "model_config.vocab_size",
                         self.vocab.size)
        self.train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=self.config.training.shuffle,
            num_workers=self.config.training.num_workers,
            pin_memory=self.config.training.pin_memory,
            collate_fn=custom_collate_fn,
            drop_last=True)
        self.val_loader = DataLoader(
            dataset=val_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=self.config.training.shuffle,
            num_workers=self.config.training.num_workers,
            pin_memory=self.config.training.pin_memory,
            collate_fn=custom_collate_fn)
        self.logger.write("Number of training samples: {}".format(
            train_dataset.__len__()))

    def build_model(self):
        self.logger.write("Building model")
        model_name = self.config.model_config.model_name
        if model_name == "cnn_lstm_caption":
            self.model = CNNLSTMCaption(self.config.model_config,
                                        self.vocab,
                                        self.device,
                                        teacher_forcing=True)
        elif model_name == "cnn_attention_lstm":
            self.model = AttentionModel(self.config.model_config,
                                        self.vocab,
                                        self.device,
                                        teacher_forcing=True)
        else:
            raise ValueError("{} model is not supported.".format(model_name))
        if self.model.audio_encoder.pretrained_version is not None and not self.model.finetune:
            for param in self.model.audio_encoder.feature_extractor.parameters(
            ):
                param.requires_grad = False
        self.model.to(self.device)

    def count_parameters(self):
        """ Count trainable parameters in model. """
        return sum(p.numel() for p in self.model.parameters()
                   if p.requires_grad)

    def build_loss(self):
        self.logger.write("Building loss")
        loss_name = self.config.model_config.loss
        if loss_name == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(ignore_index=self.vocab.PAD_INDEX)
        else:
            raise ValueError("{} loss is not supported.".format(loss_name))
        self.loss = self.loss.to(self.device)

    def build_optimizer(self):
        self.logger.write("Building optimizer")
        optimizer_name = self.config.training.optimizer
        if optimizer_name == "adam":
            self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        elif optimizer_name == "adadelta":
            self.optimizer = Adadelta(self.model.parameters(), lr=self.lr)
        else:
            raise ValueError(
                "{} optimizer is not supported.".format(optimizer_name))

    def train(self):
        if os.path.exists(self.logger.checkpoint_path):
            self.logger.write("Resumed training experiment with id {}".format(
                self.logger.experiment_id))
            self.load_ckp(self.logger.checkpoint_path)
        else:
            self.logger.write("Started training experiment with id {}".format(
                self.logger.experiment_id))
            self.start_epoch = 0

        # Adaptive learning rate
        scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                   mode='min',
                                                   factor=0.5,
                                                   patience=self.patience,
                                                   verbose=True)

        k_patience = 0
        best_val = np.Inf

        for epoch in range(self.start_epoch, self.config.training.epochs):
            epoch_start_time = time.time()

            train_loss = self.train_epoch(self.train_loader,
                                          self.device,
                                          is_training=True)
            val_loss = self.train_epoch_val(self.val_loader,
                                            self.device,
                                            is_training=False)

            # Decrease the learning rate after not improving in the validation set
            scheduler.step(val_loss)

            # check if val loss has been improving during patience period. If not, stop
            is_val_improving = scheduler.is_better(val_loss, best_val)
            if not is_val_improving:
                k_patience += 1
            else:
                k_patience = 0
            if k_patience > self.patience * 2:
                print("Early Stopping")
                break

            best_val = scheduler.best

            epoch_time = time.time() - epoch_start_time
            lr = self.optimizer.param_groups[0]['lr']

            self.logger.update_training_log(epoch + 1, train_loss, val_loss,
                                            epoch_time, lr)

            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }

            # save checkpoint in appropriate path (new or best)
            self.logger.save_checkpoint(state=checkpoint,
                                        is_best=is_val_improving)

    def load_ckp(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.start_epoch = checkpoint['epoch']

    def train_epoch(self, data_loader, device, is_training):
        out_list = []
        target_list = []
        running_loss = 0.0
        n_batches = 0

        if is_training:
            self.model.train()
            if self.model.audio_encoder.pretrained_version is not None:
                for module in self.model.audio_encoder.feature_extractor.modules(
                ):
                    if isinstance(module, nn.BatchNorm2d) or isinstance(
                            module, nn.BatchNorm1d):
                        module.eval()
        else:
            self.model.eval()

        for i, batch in enumerate(data_loader):
            audio, audio_len, x, x_len = batch
            target_list.append(x)
            audio = audio.float().to(device=device)
            x = x.long().to(device=device)
            audio_len = audio_len.to(device=device)  # .to() is not in-place; keep the moved tensor
            out = self.model(audio, audio_len, x, x_len)

            out_list.append(out)

            target = x[:, 1:]  # target excluding sos token
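            # nn.CrossEntropyLoss expects logits as (batch, classes, seq_len), so move the vocabulary dim to position 1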
            out = out.transpose(1, 2)
            loss = self.loss(out, target)

            if is_training:
                self.optimizer.zero_grad()
                loss.backward()
                if self.config.training.clip_gradients:
                    clip_grad_norm_(self.model.parameters(), 12)
                self.optimizer.step()

            running_loss += loss.item()

            n_batches += 1

        return running_loss / n_batches

    def train_epoch_val(self, data_loader, device, is_training=False):
        with torch.no_grad():
            loss = self.train_epoch(data_loader, device, is_training=False)
        return loss
Example #11
0
def main(args):

    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)
    if args.init_model_cn is not None:
        args.init_model_cn = os.path.expanduser(args.init_model_cn)
    if args.init_model_cd is not None:
        args.init_model_cd = os.path.expanduser(args.init_model_cd)
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directory (if necessary)
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)
    for s in ['phase_1', 'phase_2', 'phase_3']:
        if not os.path.exists(os.path.join(args.result_dir, s)):
            os.makedirs(os.path.join(args.result_dir, s))

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mean pixel value of training dataset
    mpv = 0.
    if args.mpv is None:
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value for training dataset...')
        for imgpath in train_dset.imgpaths:
            img = Image.open(imgpath)
            x = np.array(img, dtype=np.float32) / 255.
            mpv += x.mean()
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = args.mpv
    mpv = torch.tensor(mpv).to(gpu)
    alpha = torch.tensor(args.alpha).to(gpu)

    # save training config
    args_dict = vars(args)
    args_dict['mpv'] = float(mpv)
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # ================================================
    # Training Phase 1
    # ================================================
    # Phase 1 is skipped in this variant: a completion network pretrained by the
    # original Torch7 implementation is loaded instead. The standard Phase 1
    # setup and training loop are kept below for reference but stay disabled
    # inside the string literal.
    data = load_lua('./glcic/completionnet_places2.t7')
    model_cn = data.model
    model_cn = model_cn.to(gpu)  # move to the GPU, otherwise Phase 2 fails with a device mismatch
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())
    """
    model_cn = CompletionNetwork()
    if args.init_model_cn != None:
        model_cn.load_state_dict(torch.load(args.init_model_cn, map_location='cpu'))

    if args.data_parallel:
        model_cn = DataParallel(model_cn)

    model_cn = model_cn.to(gpu)

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:

            # forward
            x = x.to(gpu)
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)),
                hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])),
                max_holes=args.max_holes,
            ).to(gpu)
            output = model_cn(x - x * msk + mpv * msk)
            loss = completion_network_loss(x, output, msk)

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                # clear grads
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test
                if pbar.n % args.snaperiod_1 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu)
                        msk = gen_input_mask(
                            shape=x.shape,
                            hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        input = x - x * msk + mpv * msk
                        output = model_cn(input)
                        completed = poisson_blend(input, output, msk)
                        imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        torch.save(model_cn.state_dict(), model_cn_path)
                # terminate
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    """
    # ================================================
    # Training Phase 2
    # ================================================
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
    )
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.data_parallel:
        model_cd = DataParallel(model_cd)  # wrap after loading so the checkpoint keys match
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    model_cd = model_cd.to(gpu)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:

            # fake forward
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            output_cn = model_cn(x - x * msk + mpv * msk)
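            # detach so the discriminator update does not backpropagate into the completion network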
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cd.step()
                # clear grads
                opt_cd.zero_grad()
                # update progbar
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test
                if pbar.n % args.snaperiod_2 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        msk = gen_input_mask(
                            shape=x.shape,
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        input = x - x * msk + mpv * msk
                        output = model_cn(input)
                        completed = poisson_blend(input, output, msk)
                        imgs = torch.cat(
                            (x.cpu(), input.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_2',
                                               'step%d.png' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_2',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:

            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)

            # fake forward
            fake = torch.zeros((len(x), 1)).to(gpu)
            output_cn = model_cn(x - x * msk + mpv * msk)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_cd_real = bceloss(output_real, real)

            # reduce
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()

            # forward model_cn
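            # generator objective: reconstruction loss plus the adversarial term weighted by alpha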
            loss_cn_1 = completion_network_loss(x, output_cn, msk)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, (input_gd_fake)))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.

            # backward model_cn
            loss_cn.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cd.step()
                opt_cn.step()
                # clear grads
                opt_cd.zero_grad()
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()
                # test
                if pbar.n % args.snaperiod_3 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        msk = gen_input_mask(
                            shape=x.shape,
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        input = x - x * msk + mpv * msk
                        output = model_cn(input)
                        completed = poisson_blend(input, output, msk)
                        imgs = torch.cat(
                            (x.cpu(), input.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_3',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cn_step%d' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        torch.save(model_cn.state_dict(), model_cn_path)
                        torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
Example #12
0
    def train_stage_2(self):

        batch = 240
        lr1 = 0.15

        data_set = loader(os.path.join(os.getcwd(), 'data_2'), {"mode": "training"})
        data_set_test = loader(os.path.join(os.getcwd(), 'data_2'),{"mode": "test"}, data_set.index)
        data_set_eval = loader(os.path.join(os.getcwd(), 'data_2'),{"mode": "eval"}, data_set.index)

        data_loader = DataLoader(data_set, batch, True, collate_fn=call_back.detection_collate_RPN)
        data_loader_test = DataLoader(data_set_test, batch, False, collate_fn=call_back.detection_collate_RPN)
        data_loader_eval = DataLoader(data_set_eval, batch, False, collate_fn=call_back.detection_collate_RPN)

        # optim = Adadelta(self.ROI.parameters(), lr=lr1, weight_decay=1e-5)
        start_time = time.time()
        optim_a = Adadelta([{'params': self.pre.parameters()},
                            {'params': self.ROI.parameters()}], lr=0.15, weight_decay=1e-5)
        cfg.test = False
        count = 0
        for epoch in range(200):
            runing_losss = 0.0
            cls_loss = 0
            coor_loss = 0
            cls_loss2 = 0
            coor_loss2 = 0
            count += 1
            # base_time = RPN_time = ROI_time = nms_time = pre_gt = loss_time = linear_time = 0
            for data in data_loader:
                y = data[1]
                x = data[0].cuda()
                peak = data[2]
                num = data[3]
                optim_a.zero_grad()

                with torch.no_grad():
                    if self.flag >= 2:
                        result = self.base_process(x, y, peak)
                        feat1 = result['feat_8']
                        feat2 = result['feat_16']
                        feat3 = result['feat_32']
                        feat4 = result['feat_64']
                        label = result['label']
                        loss_box = result['loss_box']
                        cross_entropy = result['cross_entropy']

                cls_score = self.pre(feat1, feat2, feat3, feat4)
                cls_score = self.ROI(cls_score)

                cross_entropy2 = self.tool2.cal_loss2(cls_score, label)

                loss_total = cross_entropy2
                loss_total.backward()
                optim_a.step()
                runing_losss += loss_total.item()
                cls_loss2 += cross_entropy2.item()
                cls_loss += cross_entropy.item()
                coor_loss += loss_box.item()
            end_time = time.time()
            torch.cuda.empty_cache()
            print(
                "epoch:{a} time:{ff}: loss:{b:.4f} cls:{d:.4f} cor{e:.4f} cls2:{f:.4f} cor2:{g:.4f} date:{fff}".format(
                    a=epoch,
                    b=runing_losss,
                    d=cls_loss,
                    e=coor_loss,
                    f=cls_loss2,
                    g=coor_loss2, ff=int(end_time - start_time),
                    fff=time.asctime()))
            # if epoch % 10 == 0:
            #     adjust_learning_rate(optim, 0.9, epoch, 50, lr1)
            p = None

            # if epoch % 2 == 0:
            #     print("test result")
            # save(self.RPN.module.state_dict(),
            #      os.path.join(os.getcwd(), str(epoch) + 'rpn_a2.p'))
            # save(self.RPN.module.state_dict(),
            #      os.path.join(os.getcwd(), str(epoch) + 'base_a2.p'))
            start_time = end_time
        all_data = []
        all_label = []
        for data in data_loader:
            y = data[1]
            x = data[0].cuda()
            num = data[3]
            peak = data[2]
            with torch.no_grad():
                if self.flag >= 2:
                    result = self.base_process_2(x, y, peak)
                    data_ = result['x']
                    label = result['label']
                    loss_box = result['loss_box']
                    cross_entropy = result['cross_entropy']
                    all_data.extend(data_.cpu())
                    all_label.extend(label.cpu())
        for data in data_loader_eval:
            y = data[1]
            x = data[0].cuda()
            num = data[3]
            peak = data[2]
            with torch.no_grad():
                if self.flag >= 2:
                    result = self.base_process_2(x, y, peak)
                    data_ = result['x']
                    label = result['label']
                    loss_box = result['loss_box']
                    cross_entropy = result['cross_entropy']
                    all_data.extend(data_.cpu())
                    all_label.extend(label.cpu())
        for data in data_loader_test:
            y = data[1]
            x = data[0].cuda()
            num = data[3]
            peak = data[2]
            with torch.no_grad():
                if self.flag >= 2:
                    result = self.base_process_2(x, y, peak)
                    data_ = result['x']
                    label = result['label']
                    loss_box = result['loss_box']
                    cross_entropy = result['cross_entropy']
                    all_data.extend(data_.cpu())
                    all_label.extend(label.cpu())

        all_data = torch.stack(all_data, 0).numpy()
        all_label = torch.LongTensor(all_label).numpy()
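        # oversample the minority classes with SMOTE before re-training the ROI head from scratch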
        from imblearn.over_sampling import SMOTE
        fun = SMOTE()
        all_data, all_label = fun.fit_resample(all_data, all_label)
        total = len(all_label)
        training_label = all_label[:int(0.7 * total)]
        training_data = all_data[:int(0.7 * total)]

        test_label = all_label[-int(0.2 * total):]
        test_data = all_data[-int(0.2 * total):]
        count = 0
        self.ROI = roi().cuda()
        self.ROI = DataParallel(self.ROI, device_ids=[0])
        self.ROI.apply(weights_init)

        optim_b = Adadelta(self.ROI.parameters(), lr=0.15, weight_decay=1e-5)
        for epoch in range(1200):
            runing_losss = 0.0
            cls_loss = 0
            coor_loss = 0
            cls_loss2 = 0
            coor_loss2 = 0
            count += 1
            optim_b.zero_grad()
            optim_a.zero_grad()

            # base_time = RPN_time = ROI_time = nms_time = pre_gt = loss_time = linear_time = 0
            for j in range(int(len(training_label) / 240)):
                data_ = torch.Tensor(training_data[j * 240:j * 240 + 240]).view(240, 1024, 15).cuda()
                label_ = torch.LongTensor(training_label[j * 240:j * 240 + 240]).cuda()
                optim_b.zero_grad()

                cls_score = self.ROI(data_)
                cross_entropy2 = self.tool2.cal_loss2(cls_score, label_)

                loss_total = cross_entropy2
                loss_total.backward()
                optim_b.step()
                runing_losss += loss_total.item()
                cls_loss2 += cross_entropy2.item()
                # cross_entropy and loss_box are not recomputed in this loop, so
                # only the ROI classification loss is tracked here
            end_time = time.time()
            torch.cuda.empty_cache()
            print(
                "epoch:{a} time:{ff}: loss:{b:.4f} cls:{d:.4f} cor{e:.4f} cls2:{f:.4f} cor2:{g:.4f} date:{fff}".format(
                    a=epoch,
                    b=runing_losss,
                    d=cls_loss,
                    e=coor_loss,
                    f=cls_loss2,
                    g=coor_loss2, ff=int(end_time - start_time),
                    fff=time.asctime()))
            if epoch % 10 == 0 and epoch > 0:
                adjust_learning_rate(optim_b, 0.9, epoch, 50, 0.3)

            p = None
            self.eval_(test_data, test_label)
            # self.ROI_eval(data_loader_eval, {"epoch": epoch})

            start_time = end_time
        print('finish')
Example #13
0
def train(conf, args=None):
    pdata = PoemData()
    pdata.read_data(conf)
    pdata.get_vocab()
    model = MyPoetryModel(pdata.vocab_size, conf.embedding_dim,
                          conf.hidden_dim)

    train_data = pdata.train_data
    test_data = pdata.test_data

    train_data = torch.from_numpy(np.array(train_data['pad_words']))
    dev_data = torch.from_numpy(np.array(test_data['pad_words']))

    dataloader = DataLoader(train_data,
                            batch_size=conf.batch_size,
                            shuffle=True,
                            num_workers=conf.num_workers)
    devloader = DataLoader(dev_data,
                           batch_size=conf.batch_size,
                           shuffle=False,
                           num_workers=conf.num_workers)
    if args.optim == "Adadelta":
        optimizer = Adadelta(model.parameters(), lr=conf.learning_rate)
        print("adadelta")
    elif args.optim == "SGD":
        optimizer = SGD(model.parameters(),
                        lr=conf.learning_rate,
                        momentum=0.8,
                        nesterov=True)
        print("SGD")
    elif args.optim == "Adagrad":
        optimizer = Adagrad(model.parameters())
        print("Adagrad")
    else:
        optimizer = Adam(model.parameters(), lr=conf.learning_rate)
        print("default: Adam")

    criterion = nn.CrossEntropyLoss()
    loss_meter = meter.AverageValueMeter()

    if conf.load_best_model:
        model.load_state_dict(torch.load(conf.best_model_path))
        print("loading_best_model from {0}".format(conf.best_model_path))
    if conf.use_gpu:
        model.cuda()
        criterion.cuda()
    step = 0
    bestppl = 1e9
    early_stop_controller = 0
    for epoch in range(conf.n_epochs):
        losses = []
        loss_meter.reset()
        model.train()
        for i, data in enumerate(dataloader):
            data = data.long().transpose(1, 0).contiguous()
            if conf.use_gpu:
                #print("Cuda")
                data = data.cuda()
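            # next-token language-model targets: input and target sequences are shifted by one position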
            input, target = data[:-1, :], data[1:, :]
            optimizer.zero_grad()
            output, _ = model(input)
            loss = criterion(output, target.contiguous().view(-1))
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            loss_meter.add(loss.item())
            step += 1
            if step % 100 == 0:
                print("epoch_%d_step_%d_loss:%0.4f" %
                      (epoch + 1, step, loss.item()))
        train_loss = float(loss_meter.value()[0])

        model.eval()
        loss_meter.reset()  # reset so the perplexity below is computed on the dev set only
        for i, data in enumerate(devloader):
            data = data.long().transpose(1, 0).contiguous()
            if conf.use_gpu:
                data = data.cuda()
            input, target = data[:-1, :], data[1:, :]
            output, _ = model(input)
            loss = criterion(output, target.view(-1))
            loss_meter.add(loss.item())
        ppl = math.exp(loss_meter.value()[0])
        print("epoch_%d_loss:%0.4f , ppl:%0.4f" % (epoch + 1, train_loss, ppl))

        if epoch % conf.save_every == 0:
            torch.save(model.state_dict(),
                       "{0}_{1}".format(conf.model_prefix, epoch))

            fout = open("{0}out_{1}".format(conf.out_path, epoch),
                        'w',
                        encoding='utf-8')
            for word in list('日红山夜湖海月'):
                gen_poetry = generate_poet(model, word, pdata.vocab, conf)
                fout.write("".join(gen_poetry) + '\n\n')
            fout.close()
        if ppl < bestppl:
            bestppl = ppl
            early_stop_controller = 0
            torch.save(model.state_dict(), "{0}".format(conf.best_model_path))
        else:
            early_stop_controller += 1
        if early_stop_controller > 10:
            print("early stop.")
            break
Example #14
0
    def train(self):
        os.environ["CUDA_VISIBLE_DEVICES"] = '2'
        device_ids = [0]
        self.classifier = classifier()
        get_paprams(self.classifier)
        get_paprams(self.classifier.base)
        # data_set_eval = my_dataset(eval=True)
        # data_set = my_dataset_10s()
        # data_set_test = my_dataset_10s()
        data_set = my_dataset_10s_smote()
        data_set_test = my_dataset_10s_smote(test=True,
                                             all_data=data_set.all_data,
                                             all_label=data_set.all_label,
                                             index_=data_set.index)
        # data_set_eval = my_dataset_10s(eval=True)
        # data_set_combine = my_dataset(combine=True)
        batch = 300
        total_epoch = 2000  # used by adjust_learning_rate below
        # print('batch:{}'.format(batch))
        # self.evaluation = evaluation
        data_loader = DataLoader(data_set,
                                 batch,
                                 shuffle=True,
                                 collate_fn=detection_collate)
        data_loader_test = DataLoader(data_set_test,
                                      batch,
                                      False,
                                      collate_fn=detection_collate)
        # data_loader_eval = DataLoader(data_set_eval, batch, False, collate_fn=detection_collate)
        self.classifier = self.classifier.cuda()
        self.classifier = DataParallel(self.classifier, device_ids=device_ids)
        optim = Adadelta(self.classifier.parameters(),
                         0.1,
                         0.9,
                         weight_decay=1e-5)

        self.cretion = smooth_focal_weight()

        # data_loader_combine = DataLoader(data_set_combine, 225, False, collate_fn=detection_collate)
        self.classifier.apply(weights_init)
        start_time = time.time()
        count = 0
        epoch = -1
        while 1:
            epoch += 1
            runing_losss = [0] * 5
            for data in data_loader:
                loss = [0] * 5
                y = data[1].cuda()
                x = data[0].cuda()
                optim.zero_grad()

                weight = torch.Tensor([0.5, 2, 0.5, 2]).cuda()

                predict = self.classifier(x)

                for i in range(5):
                    loss[i] = self.cretion(predict[i], y, weight)
                tmp = sum(loss)

                tmp.backward()
                # loss5.backward()
                optim.step()
                for i in range(5):
                    runing_losss[i] += loss[i].item()

                count += 1
                # torch.cuda.empty_cache()
            end_time = time.time()
            print("epoch:{a}: loss:{b} spend_time:{c} time:{d}".format(
                a=epoch,
                b=sum(runing_losss),
                c=int(end_time - start_time),
                d=time.asctime()))
            start_time = end_time

            save(self.classifier.module.base.state_dict(),
                 str(epoch) + 'base_c1.p')
            # assumed distinct filename below; the original saved both state dicts
            # to 'base_c1.p', so the full classifier overwrote the backbone file
            save(self.classifier.module.state_dict(),
                 str(epoch) + 'classifier_c1.p')

            self.classifier.eval()

            self.evaluation(self.classifier, data_loader_test, epoch)
            # self.evaluation(self.classifier, data_loader, epoch)

            self.classifier.train()
            if epoch % 10 == 0:
                adjust_learning_rate(optim, 0.9, epoch, total_epoch, 0.1)
Example #15
0
def main(args):
    # ================================================
    # Preparation
    # ================================================
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directory (if necessary)
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        if not os.path.exists(os.path.join(args.result_dir, phase)):
            os.makedirs(os.path.join(args.result_dir, phase))

    # load dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mpv (mean pixel value) of training dataset
    if args.mpv is None:
        mpv = np.zeros(shape=(1, ))
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value of training dataset...')
        for imgpath in train_dset.imgpaths:
            img = Image.open(imgpath)
            x = np.array(img) / 255.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # save training config
    mpv_json = []
    for i in range(1):
        mpv_json.append(float(mpv[i]))
    args_dict = vars(args)
    # args_dict['mpv'] = mpv_json
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensors
    mpv = torch.tensor(mpv.reshape(1, 1, 1, 1), dtype=torch.float32).to(gpu)
    alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu)

    # ================================================
    # Training Phase 1
    # ================================================
    # load completion network
    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    model_cn = model_cn.to(gpu)
    opt_cn = Adadelta(model_cn.parameters())

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            # forward
            x = x.to(gpu)
            mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                         x.shape[3]), ).to(gpu)
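            # fill the hole region with the mean pixel value and feed the mask as an extra input channel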
            x_mask = x - x * mask + mpv * mask
            input = torch.cat((x_mask, mask), dim=1)
            output = model_cn(input)
            loss = completion_network_loss(x, output, mask)

            # backward
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_1 == 0:
                    model_cn.eval()
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                                     x.shape[3]), ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = rejoiner(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_1',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_1',
                            'model_cn_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                    model_cn.train()
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # load context discriminator
    model_cd = ContextDiscriminator(
        local_input_shape=(1, args.ld_input_size, args.ld_input_size),
        global_input_shape=(1, args.cn_input_size, args.cn_input_size),
    )
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    model_cd = model_cd.to(gpu)
    opt_cd = Adadelta(model_cd.parameters(), lr=0.1)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            # fake forward
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                         x.shape[3]), ).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cd.step()
                opt_cd.zero_grad()
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_2 == 0:
                    model_cn.eval()
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                                     x.shape[3]), ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = rejoiner(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_2',
                                               'step%d.png' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_2',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cd.state_dict(), model_cd_path)
                    model_cn.train()
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                         x.shape[3]), ).to(gpu)

            # fake forward
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_cd_real = bceloss(output_real, real)

            # reduce
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()
            cnt_bdivs += 1
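            # the discriminator is stepped here; the generator is stepped after its own backward pass below, gated by the same accumulation counter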
            if cnt_bdivs >= args.bdivs:
                # optimize
                opt_cd.step()
                opt_cd.zero_grad()

            # forward model_cn
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, (input_gd_fake)))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.

            # backward model_cn
            loss_cn.backward()
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()

                # test
                if pbar.n % args.snaperiod_3 == 0:
                    model_cn.eval()
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                                     x.shape[3]), ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = rejoiner(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_3',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cn_step%d' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                            torch.save(model_cd.state_dict(), model_cd_path)
                    model_cn.train()
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
Example #16
0
def main():
    args = read_args(default_config="confs/kim_cnn_sst2.json")
    set_seed(args.seed)
    os.makedirs(args.workspace, exist_ok=True)
    torch.backends.cudnn.deterministic = True

    dataset_cls = find_dataset(args.dataset_name)
    training_iter, dev_iter, test_iter = dataset_cls.iters(
        args.dataset_path,
        args.vectors_file,
        args.vectors_dir,
        batch_size=args.batch_size,
        device=args.device,
        train=args.train_file,
        dev=args.dev_file,
        test=args.test_file)

    args.dataset = training_iter.dataset
    args.words_num = len(training_iter.dataset.TEXT_FIELD.vocab)
    model = mod.SiameseRNNModel(args).to(args.device)

    sd = torch.load('sst.pt')['state_dict']
    del sd['static_embed.weight']
    del sd['non_static_embed.weight']
    del sd['fc1.weight']
    del sd['fc1.bias']
    del sd['fc2.weight']
    del sd['fc2.bias']
    model.load_state_dict(sd, strict=False)
    mod.init_embedding(model, args)
    # embs, field_src  = torch.load('embs_tmp.pt')
    # field_mappings = list_field_mappings(dataset_cls.TEXT_FIELD, field_src)
    # replace_embeds(model.non_static_embed, embs, field_mappings)
    model.to(args.device)

    ckpt_attrs = mod.load_checkpoint(
        model, args.workspace, best=args.load_best_checkpoint
    ) if args.load_last_checkpoint or args.load_best_checkpoint else {}
    # export the tuned embeddings together with the vocabulary and stop here;
    # the training loop below is intentionally skipped by this early return
    torch.save((model.non_static_embed, dataset_cls.TEXT_FIELD.vocab),
               'qqp-embs.pt')
    return
    offset = ckpt_attrs.get("epoch_idx", -1) + 1
    args.epochs -= offset

    training_pbar = tqdm(total=len(training_iter), position=2)
    training_pbar.set_description("Training")
    dev_pbar = tqdm(total=args.epochs, position=1)
    dev_pbar.set_description("Dev")

    criterion = nn.CrossEntropyLoss()
    kd_criterion = nn.MSELoss()  # KLDivLoss(reduction="batchmean")
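    # only the fully connected head is handed to the optimizer; the remaining pretrained weights are not updated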
    filter_params = [(n, p) for n, p in model.named_parameters()
                     if p.requires_grad and 'fc' in n]
    params = list(map(lambda x: x[1], filter_params))
    # print([x[0] for x in filter_params])
    optimizer = Adadelta(params, lr=args.lr, rho=0.95)
    #optimizer = Adam(params, lr=args.lr)
    increment_fn = mod.make_checkpoint_incrementer(model,
                                                   args.workspace,
                                                   save_last=True,
                                                   best_loss=ckpt_attrs.get(
                                                       "best_dev_loss", 10000))
    non_embedding_params = model.non_embedding_params()

    if args.use_data_parallel:
        model = nn.DataParallel(model)
    if args.eval_test_only:
        test_acc, _ = evaluate(model,
                               test_iter,
                               criterion,
                               export_eval_labels=args.export_eval_labels)
        print(test_acc)
        return
    if args.epochs == 0:
        print("No epochs left from loaded model.", file=sys.stderr)
        return
    for epoch_idx in tqdm(range(args.epochs), position=0):
        training_iter.init_epoch()
        model.train()
        training_pbar.n = 0
        training_pbar.refresh()
        for batch in training_iter:
            training_pbar.update(1)
            optimizer.zero_grad()
            logits = model(batch.question1, batch.question2)
            # kd_logits = torch.stack((batch.logits_0, batch.logits_1), 1)
            #kd = args.distill_lambda * kd_criterion(F.log_softmax(logits / args.distill_temperature, 1),
            #    F.softmax(kd_logits / args.distill_temperature, 1))
            # kd = args.distill_lambda * kd_criterion(logits, kd_logits)
            loss = criterion(logits, batch.is_duplicate)
            loss.backward()
            clip_grad_norm_(non_embedding_params, args.clip_grad)
            optimizer.step()
            acc = ((logits.max(1)[1] == batch.is_duplicate).float().sum() /
                   batch.is_duplicate.size(0)).item()
            training_pbar.set_postfix(accuracy=f"{acc:.2}")

        model.eval()
        dev_acc, dev_loss = evaluate(model, dev_iter, criterion)
        dev_pbar.update(1)
        dev_pbar.set_postfix(accuracy=f"{dev_acc:.4}")
        is_best_dev = increment_fn(dev_loss,
                                   dev_acc=dev_acc,
                                   epoch_idx=epoch_idx + offset)

        if is_best_dev:
            dev_pbar.set_postfix(accuracy=f"{dev_acc:.4} (best loss)")
            test_acc, _ = evaluate(model,
                                   test_iter,
                                   criterion,
                                   export_eval_labels=args.export_eval_labels)
    training_pbar.close()
    dev_pbar.close()
    print(f"Test accuracy of the best model: {test_acc:.4f}", file=sys.stderr)
    print(test_acc)
Example #17
0
def main(args):

    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)

    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    if args.num_gpus == 1:
        # train models in a single gpu
        gpu_cn = torch.device('cuda:0')
        gpu_cd = gpu_cn
    else:
        # train models in different two gpus
        gpu_cn = torch.device('cuda:0')
        gpu_cd = torch.device('cuda:1')

    # create result directory (if necessary)
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)
    for s in ['phase_1', 'phase_2', 'phase_3']:
        if not os.path.exists(os.path.join(args.result_dir, s)):
            os.makedirs(os.path.join(args.result_dir, s))

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm)
    train_loader = DataLoader(train_dset, batch_size=args.bsize, shuffle=True)

    # compute the mean pixel value of train dataset
    mean_pv = 0.
    imgpaths = train_dset.imgpaths[:min(args.max_mpv_samples, len(train_dset))]
    if args.comp_mpv:
        pbar = tqdm(total=len(imgpaths), desc='computing the mean pixel value')
        for imgpath in imgpaths:
            img = Image.open(imgpath)
            x = np.array(img, dtype=np.float32) / 255.
            mean_pv += x.mean()
            pbar.update()
        mean_pv /= len(imgpaths)
        pbar.close()
    mpv = torch.tensor(mean_pv).to(gpu_cn)

    # save training config
    args_dict = vars(args)
    args_dict['mean_pv'] = mean_pv
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # ================================================
    # Training Phase 1
    # ================================================
    # model & optimizer
    model_cn = CompletionNetwork()
    model_cn = model_cn.to(gpu_cn)
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())

    # training
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:

            opt_cn.zero_grad()

            # generate hole area
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            # create mask
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=(
                    (args.hole_min_w, args.hole_max_w),
                    (args.hole_min_h, args.hole_max_h),
                ),
                hole_area=hole_area,
                max_holes=args.max_holes,
            )

            # merge x, mask, and mpv
            msg = 'phase 1 |'
            x = x.to(gpu_cn)
            msk = msk.to(gpu_cn)
            input = x - x * msk + mpv * msk
            output = model_cn(input)

            # optimize
            loss = completion_network_loss(x, output, msk)
            loss.backward()
            opt_cn.step()

            msg += ' train loss: %.5f' % loss.cpu()
            pbar.set_description(msg)
            pbar.update()

            # test
            if pbar.n % args.snaperiod_1 == 0:
                with torch.no_grad():

                    x = sample_random_batch(test_dset, batch_size=args.bsize)
                    x = x.to(gpu_cn)
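                    # note: msk from the last training step is reused here, so the sampled test batch must have the same batch size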
                    input = x - x * msk + mpv * msk
                    output = model_cn(input)
                    completed = poisson_blend(input, output, msk)
                    imgs = torch.cat((input.cpu(), completed.cpu()), dim=0)
                    save_image(imgs,
                               os.path.join(args.result_dir, 'phase_1',
                                            'step%d.png' % pbar.n),
                               nrow=len(x))
                    torch.save(
                        model_cn.state_dict(),
                        os.path.join(args.result_dir, 'phase_1',
                                     'model_cn_step%d' % pbar.n))

            if pbar.n >= args.steps_1:
                break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # model, optimizer & criterion
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
    )
    model_cd = model_cd.to(gpu_cd)
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    criterion_cd = BCELoss()

    # training
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:

            x = x.to(gpu_cn)
            opt_cd.zero_grad()

            # ================================================
            # fake
            # ================================================
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            # create mask
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=(
                    (args.hole_min_w, args.hole_max_w),
                    (args.hole_min_h, args.hole_max_h),
                ),
                hole_area=hole_area,
                max_holes=args.max_holes,
            )

            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            msk = msk.to(gpu_cn)
            input_cn = x - x * msk + mpv * msk
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area)
            input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd))
            output_fake = model_cd(input_fake)
            loss_fake = criterion_cd(output_fake, fake)

            # ================================================
            # real
            # ================================================
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area)
            input_real = (input_ld_real.to(gpu_cd), input_gd_real.to(gpu_cd))
            output_real = model_cd(input_real)
            loss_real = criterion_cd(output_real, real)

            # ================================================
            # optimize
            # ================================================
            loss = (loss_fake + loss_real) / 2.
            loss.backward()
            opt_cd.step()

            msg = 'phase 2 |'
            msg += ' train loss: %.5f' % loss.cpu()
            pbar.set_description(msg)
            pbar.update()

            # periodic snapshot: save completion results and the current
            # discriminator weights
            if pbar.n % args.snaperiod_2 == 0:
                with torch.no_grad():

                    x = sample_random_batch(test_dset, batch_size=args.bsize)
                    x = x.to(gpu_cn)
                    input = x - x * msk + mpv * msk
                    output = model_cn(input)
                    completed = poisson_blend(input, output, msk)
                    imgs = torch.cat((input.cpu(), completed.cpu()), dim=0)
                    save_image(imgs,
                               os.path.join(args.result_dir, 'phase_2',
                                            'step%d.png' % pbar.n),
                               nrow=len(x))
                    torch.save(
                        model_cd.state_dict(),
                        os.path.join(args.result_dir, 'phase_2',
                                     'model_cd_step%d' % pbar.n))

            if pbar.n >= args.steps_2:
                break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    # training: completion network and context discriminator are now trained
    # jointly, with alternating updates per batch
    alpha = torch.tensor(args.alpha).to(gpu_cd)
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:

            x = x.to(gpu_cn)

            # ================================================
            # train model_cd
            # ================================================
            opt_cd.zero_grad()

            # fake
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            # create mask
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=(
                    (args.hole_min_w, args.hole_max_w),
                    (args.hole_min_h, args.hole_max_h),
                ),
                hole_area=hole_area,
                max_holes=args.max_holes,
            )

            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            msk = msk.to(gpu_cn)
            input_cn = x - x * msk + mpv * msk
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area)
            input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd))
            output_fake = model_cd(input_fake)
            loss_cd_1 = criterion_cd(output_fake, fake)

            # real
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area)
            input_real = (input_ld_real.to(gpu_cd), input_gd_real.to(gpu_cd))
            output_real = model_cd(input_real)
            loss_cd_2 = criterion_cd(output_real, real)

            # optimize the discriminator on the averaged fake/real BCE losses, scaled by alpha
            loss_cd = (loss_cd_1 + loss_cd_2) * alpha / 2.
            loss_cd.backward()
            opt_cd.step()

            # ================================================
            # train model_cn
            # ================================================
            opt_cn.zero_grad()

            loss_cn_1 = completion_network_loss(x, output_cn, msk).to(gpu_cd)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area)
            input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd))
            output_fake = model_cd(input_fake)
            loss_cn_2 = criterion_cd(output_fake, real)

            # optimize the completion network on the reconstruction loss plus the
            # alpha-weighted adversarial loss
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            opt_cn.step()

            msg = 'phase 3 |'
            msg += ' train loss (cd): %.5f' % loss_cd.cpu()
            msg += ' train loss (cn): %.5f' % loss_cn.cpu()
            pbar.set_description(msg)
            pbar.update()

            # periodic snapshot: save completion results and both networks' weights
            if pbar.n % args.snaperiod_3 == 0:
                with torch.no_grad():

                    x = sample_random_batch(test_dset, batch_size=args.bsize)
                    x = x.to(gpu_cn)
                    input = x - x * msk + mpv * msk
                    output = model_cn(input)
                    completed = poisson_blend(input, output, msk)
                    imgs = torch.cat((input.cpu(), completed.cpu()), dim=0)
                    save_image(imgs,
                               os.path.join(args.result_dir, 'phase_3',
                                            'step%d.png' % pbar.n),
                               nrow=len(x))
                    torch.save(
                        model_cn.state_dict(),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cn_step%d' % pbar.n))
                    torch.save(
                        model_cd.state_dict(),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cd_step%d' % pbar.n))

            if pbar.n >= args.steps_3:
                break
    pbar.close()
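
The three-phase loop above leans on helper functions whose bodies are not shown (gen_hole_area, gen_input_mask, crop, completion_network_loss, poisson_blend). As a rough guide to the reconstruction objective used in phases 1 and 3, here is a minimal sketch of what completion_network_loss could look like, assuming the usual GLCIC convention that msk is 1 inside the holes and 0 elsewhere; the project's actual helper may differ.

import torch.nn.functional as F

def completion_network_loss(x, output, msk):
    # Sketch only: mean-squared error restricted to the hole pixels, assuming
    # msk == 1 inside holes and 0 outside (the same convention implied by the
    # fill expression `x - x * msk + mpv * msk` above).
    return F.mse_loss(output * msk, x * msk)

Under this definition, pixels outside the holes contribute nothing to the loss, which is consistent with the way the network input is built by filling only the hole regions with the mean pixel value mpv.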
Example #18
0
File: main.py Project: kosohae/NLP
            if CUDA:
                U = U.to(DEVICE)
                field.text = field.text.to(DEVICE)
                field.label = field.label.to(DEVICE)
                model = model.to(DEVICE)

            field.text = field.text.transpose(-1, 0)   # swap to batch-first: (seq_len, batch) -> (batch, seq_len)
            labels = field.label - 1                   # shift 1-based labels to 0-based
            indices = U[field.text]                    # embedding lookup: (batch, seq_len, emb_dim)
            indices = indices.permute(0, 2, 1)         # -> (batch, emb_dim, seq_len)

            optimizer.zero_grad()
            loss, train_preds = model(indices, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item(),
                             lr=optimizer.param_groups[0]['lr'])
            steps += 1
            epoch_loss += loss.item()

            with torch.no_grad():
                num += (train_preds == labels).int().sum()

            avg_loss = epoch_loss / steps
        print(f"{epoch}/{cfg.n_epochs}  AvgLoss :{avg_loss}")
        avg_acc = num / (len(train_loader) * cfg.batch_size)
        print(
            f"Epoch[{epoch}/{cfg.n_epochs}] Train Accuracy : {avg_acc*100:.2f}"
        )
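
In the fragment above, U is indexed with a batch of token ids and the result is permuted before being passed to the model, presumably so that a Conv1d-style classifier receives input shaped (batch, channels, length). Below is a minimal, self-contained sketch of that lookup-and-permute pattern; the sizes (vocab_size, emb_dim, and so on) are illustrative only and are not taken from kosohae/NLP.

import torch

vocab_size, emb_dim, batch, seq_len = 1000, 128, 4, 32
U = torch.randn(vocab_size, emb_dim)                        # embedding matrix
token_ids = torch.randint(0, vocab_size, (batch, seq_len))  # fake token ids

emb = U[token_ids]                # (batch, seq_len, emb_dim)
emb = emb.permute(0, 2, 1)        # (batch, emb_dim, seq_len), ready for Conv1d
print(emb.shape)                  # torch.Size([4, 128, 32])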
Example #19
0
File: model.py Project: mary-el/neuro
def train(model, crit, model_name, epochs=25000, start_epoch=1):
    train_x, train_y = load_dataset('train.pkl')
    test_x, test_y = load_dataset('test.pkl')
    model_name = model_name + str(crit)

    batch_size = 10
    batches = 5

    transform = Compose([
        ToPILImage(),
        Pad((75, 30), 0),
        RandomAffine(degrees=15, scale=(0.6, 1.5), shear=50),
        Lambda(random_ext),
        RandomCrop(cropped_size[::-1]),
        RandomHorizontalFlip(),
        ToTensor(),
    ])
    train_dataset = BalancedDataset(train_x, train_y, transform=transform)
    test_dataset = DrawingDataset(test_x, test_y)
    # plt.figure(figsize=(8, 8))
    # for i in range(8):
    #     plt.subplot(4, 2, i + 1), plt.imshow(np.array(train_dataset.__getitem__((i % 2, i, crit-1))[0][0]), cmap='binary')
    # plt.show()

    trainloader = DataLoader(
        train_dataset,
        sampler=BalancedSampler(train_dataset, batch_size * batches, weights=(1, 1), crit=crit),
        batch_size=batch_size)
    testloader = DataLoader(
        test_dataset,
        sampler=DrawingSampler(test_dataset, crit=crit),
        batch_size=batch_size)

    best_loss = float('inf')
    if cont_f:  # optionally resume from the most recent checkpoint
        start_epoch, best_loss = load_last(model_name, model)
    model.to(device)

    criterion = MSELoss()
    optimizer = Adadelta(model.parameters(), lr=1.0)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=1)

    with tqdm(range(start_epoch, epochs), desc='Epochs: ', position=0, initial=start_epoch, total=epochs) as pbar:
        for epoch in pbar:
            model.train()
            running_loss = 0.0
            for x, y in trainloader:
                optimizer.zero_grad()
                x = x.to(device)
                y = y.to(device)
                output = model(x)
                loss_train = criterion(output, y)
                loss_train.backward()
                optimizer.step()
                running_loss += loss_train.item()
                del output, loss_train
            running_loss /= len(trainloader)
            scheduler.step()
            # pbar.write(str(get_lr(optimizer)))

            model.eval()
            running_test_loss = 0.0
            for x, y in testloader:
                with torch.no_grad():
                    x = x.to(device)
                    y = y.to(device)
                    output = model(x)
                    loss = criterion(output, y)
                running_test_loss += loss.item()
            running_test_loss /= len(testloader)

            if best_loss > running_test_loss:
                Path('checkpoints').mkdir(exist_ok=True)
                torch.save(model.state_dict(), f'checkpoints/{model_name}_{epoch:04d}_{running_test_loss:.4f}.pth')
                pbar.write(f'Saving at epoch {epoch}, test loss: {running_test_loss}')
                best_loss = running_test_loss

            pbar.set_postfix({
                'loss': f'{running_loss:.4f}',
                'test_loss': f'{running_test_loss:.4f}'
            })
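
The loop above can resume from an earlier run via load_last, whose implementation is not shown. Given the checkpoint names written by the save call ('checkpoints/{model_name}_{epoch:04d}_{test_loss:.4f}.pth'), a hypothetical resume helper might look like the sketch below; the real load_last in mary-el/neuro may well differ.

import torch
from pathlib import Path

def load_last(model_name, model):
    # Hypothetical sketch: restore the checkpoint with the highest epoch number
    # and return (next_epoch_to_run, best_test_loss_so_far).
    ckpts = list(Path('checkpoints').glob(f'{model_name}_*.pth'))
    if not ckpts:
        return 1, float('inf')

    def parse(p):
        # '<model_name>_<epoch>_<loss>' -> (epoch, loss)
        _, epoch_str, loss_str = p.stem.rsplit('_', 2)
        return int(epoch_str), float(loss_str)

    last = max(ckpts, key=lambda p: parse(p)[0])
    model.load_state_dict(torch.load(last, map_location='cpu'))
    epoch, loss = parse(last)
    return epoch + 1, loss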