Code Example #1
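# Assumed imports for this snippet (not shown in the original listing); Meta and
# RCWA_data_loader are the project's own classes, so the commented module paths
# below are placeholders rather than confirmed locations.
import os

import numpy as np
import torch

# from meta import Meta                        # project module (assumed path)
# from data_loader import RCWA_data_loader     # project module (assumed path)
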
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda')

    config = []
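    # Each config entry is a (layer_name, params) tuple; the list is passed to Meta
    # below, which builds the network layer by layer from it.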

    if args.arch == "Unet":
        for block in range(args.NUM_DOWN_CONV):
            out_channels = (2**block) * args.HIDDEN_DIM
            if block == 0:
                config += [
                    ('conv2d', [out_channels, args.imgc, 3, 3, 1, 1]),
                    # out_c, in_c, k_h, k_w, stride, padding (conv only, no bias)
                ]
            else:
                config += [
                    ('conv2d', [out_channels, out_channels // 2, 3, 3, 1, 1]),
                    # out_c, in_c, k_h, k_w, stride, padding
                ]
            config += [
                ('leakyrelu', [0.2, False]),  # negative slope; if True, applies the activation in place
                ('bn', [out_channels])
            ]

            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('max_pool2d', [2, 2, 0])]  # kernel_size, stride, padding

        for block in range(args.NUM_DOWN_CONV - 1):
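            # Decoder path: upsample and refine. in_channels = 3 * out_channels presumably
            # accounts for the upsampled features concatenated with the encoder skip connection.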
            out_channels = (2**(args.NUM_DOWN_CONV - block - 2)) * args.HIDDEN_DIM
            in_channels = out_channels * 3
            config += [('upsample', [2])]
            config += [('conv2d', [out_channels, in_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
        config += [
            ('conv2d_b', [args.outc, args.HIDDEN_DIM, 3, 3, 1, 1])
        ]  # unlike all the bias-free conv2d layers above, this final conv2d_b has a bias
    else:
        raise ("architectures other than Unet hasn't been added!!")

    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    # print(maml)
    for name, param in maml.named_parameters():
        print(name, param.size())
    print('Total trainable parameters:', num)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = "/scratch/users/chenkaim/pytorch-models/pytorch_" + args.model_name + "_k_shot_" + str(
        args.k_spt) + "_task_num_" + str(args.task_num) + "_meta_lr_" + str(
            args.meta_lr) + "_inner_lr_" + str(
                args.update_lr) + "_num_inner_updates_" + str(args.update_step)
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    start_epoch = 0
    if (args.continue_train):
        print("Restoring weights from ",
              model_path + "/epoch_" + str(args.continue_epoch) + ".pt")
        checkpoint = torch.load(model_path + "/epoch_" +
                                str(args.continue_epoch) + ".pt")
        maml = checkpoint['model']
        maml.lr_scheduler = checkpoint['lr_scheduler']
        maml.meta_optim = checkpoint['optimizer']
        start_epoch = args.continue_epoch

    db = RCWA_data_loader(batchsz=args.task_num,
                          n_way=args.n_way,
                          k_shot=args.k_spt,
                          k_query=args.k_qry,
                          imgsz=args.imgsz,
                          data_folder=args.data_folder)

    for step in range(start_epoch, args.epoch):
        print("epoch: ", step)
        if step % EPOCH_SAVE_INTERVAL == 0:
            checkpoint = {
                'epoch': step,
                'model': maml,
                'optimizer': maml.meta_optim,
                'lr_scheduler': maml.lr_scheduler
            }
            torch.save(checkpoint, model_path + "/epoch_" + str(step) + ".pt")
        for itr in range(
                int(0.7 * db.total_data_samples /
                    ((args.k_spt + args.k_qry) * args.task_num))):
            x_spt, y_spt, x_qry, y_qry = db.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

            # set training=True to update running_mean, running_variance, bn_weights, bn_bias
            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if itr % SUMMARY_INTERVAL == 0:
                print_str = "Iteration %d: pre-inner-loop train accuracy: %.5f, post-iner-loop test accuracy: %.5f, train_loss: %.5f" % (
                    itr, accs[0], accs[-1], loss_q)
                print(print_str)

            if itr % TEST_PRINT_INTERVAL == 0:
                accs = []
                for _ in range(10):
                    # test
                    x_spt, y_spt, x_qry, y_qry = db.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):
                        test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                    x_qry_one, y_qry_one)
                        accs.append(test_acc)

                # [b, update_step+1]
                accs = np.array(accs).mean(axis=0).astype(np.float16)
                print(
                    'Meta-validation pre-inner-loop train accuracy: %.5f, meta-validation post-inner-loop test accuracy: %.5f'
                    % (accs[0], accs[-1]))

        maml.lr_scheduler.step()
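
The listing references many args fields without showing the parser. A minimal, hypothetical entry point that supplies every field used above (all defaults are illustrative assumptions, not the original settings):

if __name__ == '__main__':
    import argparse

    argparser = argparse.ArgumentParser()
    # Architecture / data settings referenced in main(); defaults are placeholders.
    argparser.add_argument('--arch', type=str, default='Unet')
    argparser.add_argument('--NUM_DOWN_CONV', type=int, default=4)
    argparser.add_argument('--HIDDEN_DIM', type=int, default=32)
    argparser.add_argument('--imgc', type=int, default=1)
    argparser.add_argument('--outc', type=int, default=1)
    argparser.add_argument('--imgsz', type=int, default=64)
    argparser.add_argument('--data_folder', type=str, default='data')
    # Meta-learning settings referenced in main().
    argparser.add_argument('--epoch', type=int, default=100)
    argparser.add_argument('--n_way', type=int, default=5)
    argparser.add_argument('--k_spt', type=int, default=5)
    argparser.add_argument('--k_qry', type=int, default=5)
    argparser.add_argument('--task_num', type=int, default=4)
    argparser.add_argument('--meta_lr', type=float, default=1e-3)
    argparser.add_argument('--update_lr', type=float, default=0.01)
    argparser.add_argument('--update_step', type=int, default=5)
    # Checkpointing / resume settings referenced in main().
    argparser.add_argument('--model_name', type=str, default='maml_unet')
    argparser.add_argument('--continue_train', action='store_true')
    argparser.add_argument('--continue_epoch', type=int, default=0)

    main(argparser.parse_args())
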
Code Example #2
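# Assumed imports for this snippet (not shown in the original listing); Meta,
# Berttokenizer and bert_getloader are the project's own classes/functions, so the
# commented module paths below are placeholders rather than confirmed locations.
import argparse
import logging
import os
import random
import time
from datetime import datetime

import numpy as np
import torch

# from meta import Meta                                    # project module (assumed path)
# from data_loader import Berttokenizer, bert_getloader    # project module (assumed path)
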
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train', help='train file')
    parser.add_argument('--val', default='val', help='val file')
    parser.add_argument('--n_way', type=int, help='n way', default=5)
    parser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
    parser.add_argument('--k_qry', type=int, help='k shot for query set', default=5)
    parser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=32)
    parser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=0.001)
    parser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)
    parser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
    parser.add_argument('--update_step_test', type=int, help='update steps for fine-tuning', default=10)
    parser.add_argument('--max_length', default=64, type=int, help='max length')
    parser.add_argument('--epoch', type=int, help='epoch number', default=4)
    parser.add_argument('--na_rate', default=0, type=int, help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--embedding', default='bert', type=str, help='"glove" or "bert".')
    parser.add_argument('--gpu', default="1,2", type=str, help='GPUs to use (sets CUDA_VISIBLE_DEVICES).')
    parser.add_argument('--type', default="cnnLinear", type=str,
                        help="type of the net: 'cnnLinear', 'concatLinear' or 'clsLinear'.")
    parser.add_argument('--filename', default=None, type=str,
                        help="log file name; if omitted, one is generated from the run settings.")
    parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16')
    args = parser.parse_args()
    # print(str(args))
    logging.info(str(args))
    if args.filename is None:
        file_name = 'log/{}way{}shot-{}-{}'.format(args.n_way, args.k_spt, args.embedding, args.type)
        dt = datetime.now()
        file_name += dt.strftime('%Y-%m-%d-%H:%M:%S-%f')
        file_name += ".log"
    else:
        file_name = os.path.join('log', args.filename)
    with open(file_name, 'w') as f:
        f.writelines(str(args).split(','))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(2020)
    np.random.seed(2020)
    torch.manual_seed(2020)
    n_gpu = torch.cuda.device_count()
    if n_gpu > 0:
        torch.cuda.manual_seed_all(2020)
    tokenizer = Berttokenizer(max_length=args.max_length)
    train_data_loader = bert_getloader(args.train, tokenizer, N=args.n_way, K=args.k_spt, Q=args.k_qry,
                                       na_rate=args.na_rate, batch_size=args.task_num)
    val_data_loader = bert_getloader(args.val, tokenizer, N=args.n_way, K=args.k_spt, Q=1,
                                     batch_size=20)
    maml = Meta(args, device, n_gpu)
    for name, parm in maml.named_parameters():
        if parm.requires_grad:
            print(name, parm.shape)
    # maml.to(device)
    # print(maml.named_parameters())
    logging.info(n_gpu)


    accses_train = []
    accses_test = []
    losses = []
    best_result = 0
    start = time.time()
    maml.to(device)
    # if torch.cuda.is_available():
    #     # maml = nn.DataParallel(maml)
    #     maml = maml.cuda()
    for epoch in range(args.epoch):
        for step, batch in enumerate(train_data_loader):
            if n_gpu >= 1:
                batch = tuple(t.to(device) for t in batch)  # multi-GPU does the scattering itself
            x_spt, y_spt, x_qry, y_qry = batch
            accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
            losses.append(loss)
            accses_train.append(accs)
            if step % 10 == 0:
                logging.info("step: %s  training acc:%s  loss:%s  cost%smin"%(step,accs,loss, (time.time() - start) // 60))
                with open(file_name, 'a') as f:
                    f.write("\nstep: {}\ttraining acc:{}\tloss:{}\tcost:{}min".format(step, accs, loss,
                                                                                      (time.time() - start) // 60))
            if step % 100 == 0 and step != 0:
                l = []
                for _ in range(10):
                    accs = []
                    x_spt, y_spt, x_qry, y_qry = next(val_data_loader)
                    x_spt = x_spt.to(device)
                    x_qry = x_qry.to(device)
                    y_spt = y_spt.to(device)
                    # y_qry = y_qry.cuda()
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):  # [N,K,MAXLEN]

                        pred = maml.evaluate(x_spt_one, y_spt_one, x_qry_one).cpu().numpy()
                        acc = (y_qry_one.numpy() == pred).mean()
                        accs.append(acc)
                    accs = np.array(accs).mean(axis=0).astype(np.float16)
                    l.append(accs)
                with open(file_name, 'a') as f:
                    f.write("\nTest acc:{}\tmean:{}\tcost:{}min".format(l, str(np.array(l).mean()),(time.time() - start) // 60))
                # logging.info('Test acc:', l, '\tmean:', np.array(l).mean(), '\tcost,', np.array(l).mean(), 'min\n')
                logging.info("Test acc:%s  mean:%s  cost:%smin"%(l,np.array(l).mean(),np.array(l).mean()))

                # print('Test acc:', l, '\tmean:', np.array(l).mean(), '\tcost,', (time.time() - start) // 60, 'min\n')
                if best_result <= np.array(l).mean():
                    best_result = np.array(l).mean()  # keep track of the best validation mean so far
                    torch.save(maml, "{}best.ckpt".format(file_name))
                accses_test.append([step, accs])
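
The parser is built inside main(), so only the usual entry-point guard is missing; a minimal one (the logging setup here is an assumption, since the original configuration is not shown):

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)  # assumed logging setup
    main()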