NumPy 函数

NumPy expand_dims() 函数扩展数组的形状。它插入一个新轴,该轴将出现在扩展数组形状中的轴位置。

语法

numpy.expand_dims(a, axis) 

    参数

    a必填。 指定输入数组。
    axis必填。 指定扩展轴中放置新轴(或多个轴)的位置。它可以是整数或整数元组。

    返回值

    返回a的视图,其中维数增加。

    示例:

    在下面的示例中,数组在给定轴上扩展。

    import numpy as np
    
    x = np.array([1, 2, 3])
    
    #在轴上扩展x的维度=0
    x1 = np.expand_dims(x, axis=0)
    
    #扩展 x 在轴上的维度=1
    x2 = np.expand_dims(x, axis=1)
    
    #扩展x在轴上的维度=(0,1)
    x3 = np.expand_dims(x, axis=(0,1))
    
    #显示结果
    print("shape of x:", x.shape)
    print("x contains:")
    print(x)
    print("\nshape of x1:", x1.shape)
    print("x1 contains:")
    print(x1)
    print("\nshape of x2:", x2.shape)
    print("x2 contains:")
    print(x2)
    print("\nshape of x3:", x3.shape)
    print("x3 contains:")
    print(x3) 
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    上述代码的输出将为:

    shape of x: (3,)
    x contains:
    [1 2 3]
    
    shape of x1: (1, 3)
    x1 contains:
    [[1 2 3]]
    
    shape of x2: (3, 1)
    x2 contains:
    [[1]
     [2]
     [3]]
    
    shape of x3: (1, 1, 3)
    x3 contains:
    [[[1 2 3]]] 
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16